diff options
author | Ondřej Surý <ondrej@sury.org> | 2011-04-28 10:35:15 +0200 |
---|---|---|
committer | Ondřej Surý <ondrej@sury.org> | 2011-04-28 10:35:15 +0200 |
commit | c1ba1a0fec4aed430709030f98a3bdb90bfeea16 (patch) | |
tree | 3df18657e50a0313ed6defcda30e4474cb28a467 /src | |
parent | 7b15ed9ef455b6b66c6b376898a88aef5d6a9970 (diff) | |
download | golang-c1ba1a0fec4aed430709030f98a3bdb90bfeea16.tar.gz |
Imported Upstream version 2011.04.27upstream/2011.04.27
Diffstat (limited to 'src')
407 files changed, 10882 insertions, 3273 deletions
diff --git a/src/Make.cmd b/src/Make.cmd index 6f88e5cc2..e769e3072 100644 --- a/src/Make.cmd +++ b/src/Make.cmd @@ -6,6 +6,10 @@ ifeq ($(GOOS),windows) TARG:=$(TARG).exe endif +ifeq ($(TARGDIR),) +TARGDIR:=$(QUOTED_GOBIN) +endif + all: $(TARG) include $(QUOTED_GOROOT)/src/Make.common @@ -13,20 +17,20 @@ include $(QUOTED_GOROOT)/src/Make.common PREREQ+=$(patsubst %,%.make,$(DEPS)) $(TARG): _go_.$O - $(LD) -o $@ _go_.$O + $(LD) $(LDIMPORTS) -o $@ _go_.$O _go_.$O: $(GOFILES) $(PREREQ) - $(GC) -o $@ $(GOFILES) + $(GC) $(GCIMPORTS) -o $@ $(GOFILES) -install: $(QUOTED_GOBIN)/$(TARG) +install: $(TARGDIR)/$(TARG) -$(QUOTED_GOBIN)/$(TARG): $(TARG) - cp -f $(TARG) $(QUOTED_GOBIN) +$(TARGDIR)/$(TARG): $(TARG) + cp -f $(TARG) $(TARGDIR) CLEANFILES+=$(TARG) _test _testmain.go nuke: clean - rm -f $(QUOTED_GOBIN)/$(TARG) + rm -f $(TARGDIR)/$(TARG) # for gotest testpackage: _test/main.a @@ -40,7 +44,7 @@ _test/main.a: _gotest_.$O gopack grc $@ _gotest_.$O _gotest_.$O: $(GOFILES) $(GOTESTFILES) - $(GC) -o $@ $(GOFILES) $(GOTESTFILES) + $(GC) $(GCIMPORTS) -o $@ $(GOFILES) $(GOTESTFILES) importpath: echo main diff --git a/src/Make.common b/src/Make.common index 34d7016f4..0b27d07f9 100644 --- a/src/Make.common +++ b/src/Make.common @@ -6,13 +6,13 @@ clean: rm -rf *.o *.a *.[$(OS)] [$(OS)].out $(CLEANFILES) install.clean: install - rm -rf *.o *.a *.[$(OS)] [$(OS)].out $(CLEANFILES) + rm -rf *.o *.a *.[$(OS)] [$(OS)].out $(CLEANFILES) || true test.clean: test - rm -rf *.o *.a *.[$(OS)] [$(OS)].out $(CLEANFILES) + rm -rf *.o *.a *.[$(OS)] [$(OS)].out $(CLEANFILES) || true testshort.clean: testshort - rm -rf *.o *.a *.[$(OS)] [$(OS)].out $(CLEANFILES) + rm -rf *.o *.a *.[$(OS)] [$(OS)].out $(CLEANFILES) || true %.make: $(MAKE) -C $* install diff --git a/src/Make.pkg b/src/Make.pkg index 59ce56ac0..966bc61c7 100644 --- a/src/Make.pkg +++ b/src/Make.pkg @@ -31,7 +31,11 @@ endif pkgdir=$(QUOTED_GOROOT)/pkg/$(GOOS)_$(GOARCH) -INSTALLFILES+=$(pkgdir)/$(TARG).a +ifeq ($(TARGDIR),) +TARGDIR:=$(pkgdir) +endif + +INSTALLFILES+=$(TARGDIR)/$(TARG).a # The rest of the cgo rules are below, but these variable updates # must be done here so they apply to the main rules. @@ -46,7 +50,7 @@ GOFILES+=$(patsubst %.swig,_obj/%.go,$(patsubst %.swigcxx,%.swig,$(SWIGFILES))) OFILES+=$(patsubst %.swig,_obj/%_gc.$O,$(patsubst %.swigcxx,%.swig,$(SWIGFILES))) SWIG_PREFIX=$(subst /,-,$(TARG)) SWIG_SOS+=$(patsubst %.swig,_obj/$(SWIG_PREFIX)-%.so,$(patsubst %.swigcxx,%.swig,$(SWIGFILES))) -INSTALLFILES+=$(patsubst %.swig,$(pkgdir)/swig/$(SWIG_PREFIX)-%.so,$(patsubst %.swigcxx,%.swig,$(SWIGFILES))) +INSTALLFILES+=$(patsubst %.swig,$(TARGDIR)/swig/$(SWIG_PREFIX)-%.so,$(patsubst %.swigcxx,%.swig,$(SWIGFILES))) endif PREREQ+=$(patsubst %,%.make,$(DEPS)) @@ -67,22 +71,22 @@ bench: gotest -test.bench=. -test.run="Do not run tests" nuke: clean - rm -f $(pkgdir)/$(TARG).a + rm -f $(TARGDIR)/$(TARG).a testpackage-clean: rm -f _test/$(TARG).a install: $(INSTALLFILES) -$(pkgdir)/$(TARG).a: _obj/$(TARG).a - @test -d $(QUOTED_GOROOT)/pkg && mkdir -p $(pkgdir)/$(dir) +$(TARGDIR)/$(TARG).a: _obj/$(TARG).a + @test -d $(QUOTED_GOROOT)/pkg && mkdir -p $(TARGDIR)/$(dir) cp _obj/$(TARG).a "$@" _go_.$O: $(GOFILES) $(PREREQ) - $(GC) -o $@ $(GOFILES) + $(GC) $(GCIMPORTS) -o $@ $(GOFILES) _gotest_.$O: $(GOFILES) $(GOTESTFILES) $(PREREQ) - $(GC) -o $@ $(GOFILES) $(GOTESTFILES) + $(GC) $(GCIMPORTS) -o $@ $(GOFILES) $(GOTESTFILES) _obj/$(TARG).a: _go_.$O $(OFILES) @mkdir -p _obj/$(dir) @@ -222,13 +226,13 @@ _obj/$(SWIG_PREFIX)-%.so: _obj/%_wrap.o _obj/$(SWIG_PREFIX)-%.so: _obj/%_wrapcxx.o $(HOST_CXX) $(_CGO_CFLAGS_$(GOARCH)) -o $@ $^ $(SWIG_LDFLAGS) $(_CGO_LDFLAGS_$(GOOS)) $(_SWIG_LDFLAGS_$(GOOS)) -$(pkgdir)/swig/$(SWIG_PREFIX)-%.so: _obj/$(SWIG_PREFIX)-%.so - @test -d $(QUOTED_GOROOT)/pkg && mkdir -p $(pkgdir)/swig +$(TARGDIR)/swig/$(SWIG_PREFIX)-%.so: _obj/$(SWIG_PREFIX)-%.so + @test -d $(QUOTED_GOROOT)/pkg && mkdir -p $(TARGDIR)/swig cp $< "$@" all: $(SWIG_SOS) -SWIG_RPATH=-r $(pkgdir)/swig +SWIG_RPATH=-r $(TARGDIR)/swig endif diff --git a/src/all-qemu.bash b/src/all-qemu.bash index b2be15ac8..6d5cd6edd 100755 --- a/src/all-qemu.bash +++ b/src/all-qemu.bash @@ -6,7 +6,6 @@ # Run all.bash but exclude tests that depend on functionality # missing in QEMU's system call emulation. -export DISABLE_NET_TESTS=1 # no external network export NOTEST="" NOTEST="$NOTEST big" # xxx diff --git a/src/cmd/5c/swt.c b/src/cmd/5c/swt.c index d45aabc5e..431f04817 100644 --- a/src/cmd/5c/swt.c +++ b/src/cmd/5c/swt.c @@ -665,7 +665,9 @@ align(int32 i, Type *t, int op, int32 *maxalign) case Aarg2: /* width of a parameter */ o += t->width; - w = SZ_LONG; + w = t->width; + if(w > SZ_LONG) + w = SZ_LONG; break; case Aaut3: /* total align of automatic */ diff --git a/src/cmd/5g/cgen.c b/src/cmd/5g/cgen.c index 032409bae..4e5f7ebcd 100644 --- a/src/cmd/5g/cgen.c +++ b/src/cmd/5g/cgen.c @@ -43,6 +43,8 @@ cgen(Node *n, Node *res) } if(isfat(n->type)) { + if(n->type->width < 0) + fatal("forgot to compute width for %T", n->type); sgen(n, res, n->type->width); goto ret; } @@ -960,7 +962,7 @@ bgen(Node *n, int true, Prog *to) } // make simplest on right - if(nl->op == OLITERAL || nl->ullman < nr->ullman) { + if(nl->op == OLITERAL || (nl->ullman < UINF && nl->ullman < nr->ullman)) { a = brrev(a); r = nl; nl = nr; @@ -1071,18 +1073,18 @@ bgen(Node *n, int true, Prog *to) a = optoas(a, nr->type); if(nr->ullman >= UINF) { - regalloc(&n1, nr->type, N); - cgen(nr, &n1); + regalloc(&n1, nl->type, N); + cgen(nl, &n1); - tempname(&tmp, nr->type); + tempname(&tmp, nl->type); gmove(&n1, &tmp); regfree(&n1); - regalloc(&n1, nl->type, N); - cgen(nl, &n1); - regalloc(&n2, nr->type, N); - cgen(&tmp, &n2); + cgen(nr, &n2); + + regalloc(&n1, nl->type, N); + cgen(&tmp, &n1); gcmp(optoas(OCMP, nr->type), &n1, &n2); patch(gbranch(a, nr->type), to); diff --git a/src/cmd/5g/gg.h b/src/cmd/5g/gg.h index ce4575be9..78e6833b2 100644 --- a/src/cmd/5g/gg.h +++ b/src/cmd/5g/gg.h @@ -52,7 +52,7 @@ struct Prog EXTERN Biobuf* bout; EXTERN int32 dynloc; -EXTERN uchar reg[REGALLOC_FMAX]; +EXTERN uchar reg[REGALLOC_FMAX+1]; EXTERN int32 pcloc; // instruction counter EXTERN Strlit emptystring; extern char* anames[]; diff --git a/src/cmd/5g/peep.c b/src/cmd/5g/peep.c index ca12d70f2..6f36e12d4 100644 --- a/src/cmd/5g/peep.c +++ b/src/cmd/5g/peep.c @@ -1134,7 +1134,7 @@ copyu(Prog *p, Adr *v, Adr *s) if(v->type == D_REG) { if(v->reg <= REGEXT && v->reg > exregoffset) return 2; - if(v->reg == REGARG) + if(v->reg == (uchar)REGARG) return 2; } if(v->type == D_FREG) @@ -1152,7 +1152,7 @@ copyu(Prog *p, Adr *v, Adr *s) case ATEXT: /* funny */ if(v->type == D_REG) - if(v->reg == REGARG) + if(v->reg == (uchar)REGARG) return 3; return 0; } diff --git a/src/cmd/5l/l.h b/src/cmd/5l/l.h index cf5a9990b..f3c9d839d 100644 --- a/src/cmd/5l/l.h +++ b/src/cmd/5l/l.h @@ -156,6 +156,7 @@ struct Sym char* file; char* dynimpname; char* dynimplib; + char* dynimpvers; // STEXT Auto* autom; diff --git a/src/cmd/5l/obj.c b/src/cmd/5l/obj.c index f252f9fc5..c4a2bfc3f 100644 --- a/src/cmd/5l/obj.c +++ b/src/cmd/5l/obj.c @@ -317,7 +317,7 @@ zaddr(Biobuf *f, Adr *a, Sym *h[]) a->sym = h[c]; a->name = Bgetc(f); - if(a->reg < 0 || a->reg > NREG) { + if((schar)a->reg < 0 || a->reg > NREG) { print("register out of range %d\n", a->reg); Bputc(f, ALAST+1); return; /* force real diagnostic */ @@ -581,7 +581,7 @@ loop: diag("multiple initialization for %s: in both %s and %s", s->name, s->file, pn); errorexit(); } - savedata(s, p); + savedata(s, p, pn); unmal(p, sizeof *p); break; diff --git a/src/cmd/6g/cgen.c b/src/cmd/6g/cgen.c index 47f3374f5..75dc4fe13 100644 --- a/src/cmd/6g/cgen.c +++ b/src/cmd/6g/cgen.c @@ -47,6 +47,8 @@ cgen(Node *n, Node *res) } if(isfat(n->type)) { + if(n->type->width < 0) + fatal("forgot to compute width for %T", n->type); sgen(n, res, n->type->width); goto ret; } @@ -827,7 +829,7 @@ bgen(Node *n, int true, Prog *to) } // make simplest on right - if(nl->op == OLITERAL || nl->ullman < nr->ullman) { + if(nl->op == OLITERAL || (nl->ullman < nr->ullman && nl->ullman < UINF)) { a = brrev(a); r = nl; nl = nr; @@ -877,18 +879,18 @@ bgen(Node *n, int true, Prog *to) } if(nr->ullman >= UINF) { - regalloc(&n1, nr->type, N); - cgen(nr, &n1); + regalloc(&n1, nl->type, N); + cgen(nl, &n1); - tempname(&tmp, nr->type); + tempname(&tmp, nl->type); gmove(&n1, &tmp); regfree(&n1); - regalloc(&n1, nl->type, N); - cgen(nl, &n1); + regalloc(&n2, nr->type, N); + cgen(nr, &n2); - regalloc(&n2, nr->type, &n2); - cgen(&tmp, &n2); + regalloc(&n1, nl->type, N); + cgen(&tmp, &n1); goto cmp; } diff --git a/src/cmd/6g/reg.c b/src/cmd/6g/reg.c index 1e1d64c59..ed8bac3f0 100644 --- a/src/cmd/6g/reg.c +++ b/src/cmd/6g/reg.c @@ -1193,7 +1193,6 @@ void paint1(Reg *r, int bn) { Reg *r1; - Prog *p; int z; uint32 bb; @@ -1219,7 +1218,6 @@ paint1(Reg *r, int bn) } for(;;) { r->act.b[z] |= bb; - p = r->prog; if(r->use1.b[z] & bb) { change += CREF * r->loop; diff --git a/src/cmd/6l/asm.c b/src/cmd/6l/asm.c index ba2074fde..dda19e48d 100644 --- a/src/cmd/6l/asm.c +++ b/src/cmd/6l/asm.c @@ -95,6 +95,8 @@ enum { ElfStrStrtab, ElfStrRelaPlt, ElfStrPlt, + ElfStrGnuVersion, + ElfStrGnuVersionR, NElfStr }; @@ -436,6 +438,7 @@ adddynsym(Sym *s) s->dynid = nelfsym++; d = lookup(".dynsym", 0); + name = s->dynimpname; if(name == nil) name = s->name; @@ -586,6 +589,8 @@ doelf(void) elfstr[ElfStrRela] = addstring(shstrtab, ".rela"); elfstr[ElfStrRelaPlt] = addstring(shstrtab, ".rela.plt"); elfstr[ElfStrPlt] = addstring(shstrtab, ".plt"); + elfstr[ElfStrGnuVersion] = addstring(shstrtab, ".gnu.version"); + elfstr[ElfStrGnuVersionR] = addstring(shstrtab, ".gnu.version_r"); /* dynamic symbol table - first entry all zeros */ s = lookup(".dynsym", 0); @@ -629,6 +634,14 @@ doelf(void) s = lookup(".rela.plt", 0); s->reachable = 1; s->type = SELFDATA; + + s = lookup(".gnu.version", 0); + s->reachable = 1; + s->type = SELFDATA; + + s = lookup(".gnu.version_r", 0); + s->reachable = 1; + s->type = SELFDATA; /* define dynamic elf table */ s = lookup(".dynamic", 0); @@ -653,7 +666,8 @@ doelf(void) elfwritedynent(s, DT_PLTREL, DT_RELA); elfwritedynentsymsize(s, DT_PLTRELSZ, lookup(".rela.plt", 0)); elfwritedynentsym(s, DT_JMPREL, lookup(".rela.plt", 0)); - elfwritedynent(s, DT_NULL, 0); + + // Do not write DT_NULL. elfdynhash will finish it. } } @@ -681,7 +695,7 @@ asmb(void) { int32 magic; int a, dynsym; - vlong vl, va, startva, fo, w, symo, elfsymo, elfstro, elfsymsize, machlink; + vlong vl, startva, symo, elfsymo, elfstro, elfsymsize, machlink; ElfEhdr *eh; ElfPhdr *ph, *pph; ElfShdr *sh; @@ -735,8 +749,11 @@ asmb(void) /* index of elf text section; needed by asmelfsym, double-checked below */ /* !debug['d'] causes extra sections before the .text section */ elftextsh = 1; - if(!debug['d']) + if(!debug['d']) { elftextsh += 10; + if(elfverneed) + elftextsh += 2; + } break; case Hwindows: break; @@ -846,10 +863,7 @@ asmb(void) /* elf amd-64 */ eh = getElfEhdr(); - fo = HEADR; startva = INITTEXT - HEADR; - va = startva + fo; - w = segtext.filelen; /* This null SHdr must appear before all others */ sh = newElfShdr(elfstr[ElfStrEmpty]); @@ -923,6 +937,24 @@ asmb(void) sh->addralign = 1; shsym(sh, lookup(".dynstr", 0)); + if(elfverneed) { + sh = newElfShdr(elfstr[ElfStrGnuVersion]); + sh->type = SHT_GNU_VERSYM; + sh->flags = SHF_ALLOC; + sh->addralign = 2; + sh->link = dynsym; + sh->entsize = 2; + shsym(sh, lookup(".gnu.version", 0)); + + sh = newElfShdr(elfstr[ElfStrGnuVersionR]); + sh->type = SHT_GNU_VERNEED; + sh->flags = SHF_ALLOC; + sh->addralign = 8; + sh->info = elfverneed; + sh->link = dynsym+1; // dynstr + shsym(sh, lookup(".gnu.version_r", 0)); + } + sh = newElfShdr(elfstr[ElfStrRelaPlt]); sh->type = SHT_RELA; sh->flags = SHF_ALLOC; diff --git a/src/cmd/6l/l.h b/src/cmd/6l/l.h index 4fc13b94a..33ca51b2c 100644 --- a/src/cmd/6l/l.h +++ b/src/cmd/6l/l.h @@ -148,6 +148,7 @@ struct Sym char* file; char* dynimpname; char* dynimplib; + char* dynimpvers; // STEXT Auto* autom; diff --git a/src/cmd/6l/obj.c b/src/cmd/6l/obj.c index 6b43d2df4..d53814a74 100644 --- a/src/cmd/6l/obj.c +++ b/src/cmd/6l/obj.c @@ -356,6 +356,15 @@ zaddr(char *pn, Biobuf *f, Adr *a, Sym *h[]) return; } } + + switch(t) { + case D_FILE: + case D_FILE1: + case D_AUTO: + case D_PARAM: + if(s == S) + mangle(pn); + } u = mal(sizeof(*u)); u->link = curauto; @@ -380,7 +389,7 @@ ldobj1(Biobuf *f, char *pkg, int64 len, char *pn) vlong ipc; Prog *p; int v, o, r, skip, mode; - Sym *h[NSYM], *s, *di; + Sym *h[NSYM], *s; uint32 sig; char *name, *x; int ntext; @@ -391,7 +400,6 @@ ldobj1(Biobuf *f, char *pkg, int64 len, char *pn) lastp = nil; ntext = 0; eof = Boffset(f) + len; - di = S; src[0] = 0; newloop: @@ -559,7 +567,7 @@ loop: diag("multiple initialization for %s: in both %s and %s", s->name, s->file, pn); errorexit(); } - savedata(s, p); + savedata(s, p, pn); unmal(p, sizeof *p); goto loop; diff --git a/src/cmd/8g/cgen.c b/src/cmd/8g/cgen.c index 9c326e8ef..596824a6c 100644 --- a/src/cmd/8g/cgen.c +++ b/src/cmd/8g/cgen.c @@ -78,6 +78,8 @@ cgen(Node *n, Node *res) // structs etc get handled specially if(isfat(n->type)) { + if(n->type->width < 0) + fatal("forgot to compute width for %T", n->type); sgen(n, res, n->type->width); return; } @@ -898,7 +900,7 @@ bgen(Node *n, int true, Prog *to) } // make simplest on right - if(nl->op == OLITERAL || nl->ullman < nr->ullman) { + if(nl->op == OLITERAL || (nl->ullman < nr->ullman && nl->ullman < UINF)) { a = brrev(a); r = nl; nl = nr; @@ -1023,8 +1025,8 @@ bgen(Node *n, int true, Prog *to) if(nr->ullman >= UINF) { tempname(&n1, nl->type); tempname(&tmp, nr->type); - cgen(nr, &tmp); cgen(nl, &n1); + cgen(nr, &tmp); regalloc(&n2, nr->type, N); cgen(&tmp, &n2); goto cmp; diff --git a/src/cmd/8g/ggen.c b/src/cmd/8g/ggen.c index 8db552493..920725c3e 100644 --- a/src/cmd/8g/ggen.c +++ b/src/cmd/8g/ggen.c @@ -625,12 +625,8 @@ void cgen_div(int op, Node *nl, Node *nr, Node *res) { Node ax, dx, oldax, olddx; - int rax, rdx; Type *t; - rax = reg[D_AX]; - rdx = reg[D_DX]; - if(is64(nl->type)) fatal("cgen_div %T", nl->type); diff --git a/src/cmd/8l/asm.c b/src/cmd/8l/asm.c index b9bd0dae9..f28b8d904 100644 --- a/src/cmd/8l/asm.c +++ b/src/cmd/8l/asm.c @@ -91,6 +91,8 @@ enum { ElfStrStrtab, ElfStrRelPlt, ElfStrPlt, + ElfStrGnuVersion, + ElfStrGnuVersionR, NElfStr }; @@ -420,7 +422,7 @@ adddynsym(Sym *s) s->dynid = nelfsym++; d = lookup(".dynsym", 0); - + /* name */ name = s->dynimpname; if(name == nil) @@ -545,6 +547,8 @@ doelf(void) elfstr[ElfStrRel] = addstring(shstrtab, ".rel"); elfstr[ElfStrRelPlt] = addstring(shstrtab, ".rel.plt"); elfstr[ElfStrPlt] = addstring(shstrtab, ".plt"); + elfstr[ElfStrGnuVersion] = addstring(shstrtab, ".gnu.version"); + elfstr[ElfStrGnuVersionR] = addstring(shstrtab, ".gnu.version_r"); /* interpreter string */ s = lookup(".interp", 0); @@ -592,6 +596,14 @@ doelf(void) s = lookup(".rel.plt", 0); s->reachable = 1; s->type = SELFDATA; + + s = lookup(".gnu.version", 0); + s->reachable = 1; + s->type = SELFDATA; + + s = lookup(".gnu.version_r", 0); + s->reachable = 1; + s->type = SELFDATA; elfsetupplt(); @@ -617,7 +629,8 @@ doelf(void) elfwritedynent(s, DT_PLTREL, DT_REL); elfwritedynentsymsize(s, DT_PLTRELSZ, lookup(".rel.plt", 0)); elfwritedynentsym(s, DT_JMPREL, lookup(".rel.plt", 0)); - elfwritedynent(s, DT_NULL, 0); + + // Do not write DT_NULL. elfdynhash will finish it. } } @@ -681,8 +694,11 @@ asmb(void) /* index of elf text section; needed by asmelfsym, double-checked below */ /* !debug['d'] causes extra sections before the .text section */ elftextsh = 1; - if(!debug['d']) + if(!debug['d']) { elftextsh += 10; + if(elfverneed) + elftextsh += 2; + } } symsize = 0; @@ -966,6 +982,24 @@ asmb(void) sh->addralign = 1; shsym(sh, lookup(".dynstr", 0)); + if(elfverneed) { + sh = newElfShdr(elfstr[ElfStrGnuVersion]); + sh->type = SHT_GNU_VERSYM; + sh->flags = SHF_ALLOC; + sh->addralign = 2; + sh->link = dynsym; + sh->entsize = 2; + shsym(sh, lookup(".gnu.version", 0)); + + sh = newElfShdr(elfstr[ElfStrGnuVersionR]); + sh->type = SHT_GNU_VERNEED; + sh->flags = SHF_ALLOC; + sh->addralign = 4; + sh->info = elfverneed; + sh->link = dynsym+1; // dynstr + shsym(sh, lookup(".gnu.version_r", 0)); + } + sh = newElfShdr(elfstr[ElfStrRelPlt]); sh->type = SHT_REL; sh->flags = SHF_ALLOC; diff --git a/src/cmd/8l/l.h b/src/cmd/8l/l.h index ac0f3953f..8f39ef519 100644 --- a/src/cmd/8l/l.h +++ b/src/cmd/8l/l.h @@ -147,6 +147,7 @@ struct Sym char* file; char* dynimpname; char* dynimplib; + char* dynimpvers; // STEXT Auto* autom; diff --git a/src/cmd/8l/obj.c b/src/cmd/8l/obj.c index d505dc10e..2a38f7ef0 100644 --- a/src/cmd/8l/obj.c +++ b/src/cmd/8l/obj.c @@ -431,7 +431,7 @@ ldobj1(Biobuf *f, char *pkg, int64 len, char *pn) int32 ipc; Prog *p; int v, o, r, skip; - Sym *h[NSYM], *s, *di; + Sym *h[NSYM], *s; uint32 sig; int ntext; int32 eof; @@ -442,7 +442,6 @@ ldobj1(Biobuf *f, char *pkg, int64 len, char *pn) lastp = nil; ntext = 0; eof = Boffset(f) + len; - di = S; src[0] = 0; @@ -600,7 +599,7 @@ loop: diag("multiple initialization for %s: in both %s and %s", s->name, s->file, pn); errorexit(); } - savedata(s, p); + savedata(s, p, pn); unmal(p, sizeof *p); goto loop; diff --git a/src/cmd/8l/prof.c b/src/cmd/8l/prof.c index 4e95fad79..d99c5e408 100644 --- a/src/cmd/8l/prof.c +++ b/src/cmd/8l/prof.c @@ -36,7 +36,7 @@ void doprof1(void) { -#if 0 // TODO(rsc) +#ifdef NOTDEF // TODO(rsc) Sym *s; int32 n; Prog *p, *q; diff --git a/src/cmd/cc/dpchk.c b/src/cmd/cc/dpchk.c index d78a72a2b..0e51101f1 100644 --- a/src/cmd/cc/dpchk.c +++ b/src/cmd/cc/dpchk.c @@ -534,6 +534,32 @@ out: print("%s incomplete\n", s->name); } +Sym* +getimpsym(void) +{ + int c; + char *cp; + + c = getnsc(); + if(isspace(c) || c == '"') { + unget(c); + return S; + } + for(cp = symb;;) { + if(cp <= symb+NSYMB-4) + *cp++ = c; + c = getc(); + if(c > 0 && !isspace(c) && c != '"') + continue; + unget(c); + break; + } + *cp = 0; + if(cp > symb+NSYMB-4) + yyerror("symbol too large: %s", symb); + return lookup(); +} + void pragdynimport(void) { @@ -541,11 +567,11 @@ pragdynimport(void) char *path; Dynimp *f; - local = getsym(); + local = getimpsym(); if(local == nil) goto err; - remote = getsym(); + remote = getimpsym(); if(remote == nil) goto err; diff --git a/src/cmd/cc/macbody b/src/cmd/cc/macbody index 35740e985..ca8a54c0b 100644 --- a/src/cmd/cc/macbody +++ b/src/cmd/cc/macbody @@ -63,7 +63,7 @@ getsym(void) if(cp <= symb+NSYMB-4) *cp++ = c; c = getc(); - if(isalnum(c) || c == '_' || c >= 0x80 || c == '$') + if(isalnum(c) || c == '_' || c >= 0x80) continue; unget(c); break; diff --git a/src/cmd/cgo/main.go b/src/cmd/cgo/main.go index 00ffc4506..84aeccc21 100644 --- a/src/cmd/cgo/main.go +++ b/src/cmd/cgo/main.go @@ -20,7 +20,6 @@ import ( "os" "reflect" "strings" - "runtime" ) // A Package collects information about the package we're going to write. @@ -135,20 +134,7 @@ func main() { // instead of needing to make the linkers duplicate all the // specialized knowledge gcc has about where to look for imported // symbols and which ones to use. - syms, imports := dynimport(*dynobj) - if runtime.GOOS == "windows" { - for _, sym := range syms { - ss := strings.Split(sym, ":", -1) - fmt.Printf("#pragma dynimport %s %s %q\n", ss[0], ss[0], strings.ToLower(ss[1])) - } - return - } - for _, sym := range syms { - fmt.Printf("#pragma dynimport %s %s %q\n", sym, sym, "") - } - for _, p := range imports { - fmt.Printf("#pragma dynimport %s %s %q\n", "_", "_", p) - } + dynimport(*dynobj) return } diff --git a/src/cmd/cgo/out.go b/src/cmd/cgo/out.go index abf8c8bc2..bc031cc58 100644 --- a/src/cmd/cgo/out.go +++ b/src/cmd/cgo/out.go @@ -95,42 +95,63 @@ func (p *Package) writeDefs() { fc.Close() } -func dynimport(obj string) (syms, imports []string) { - var f interface { - ImportedLibraries() ([]string, os.Error) - ImportedSymbols() ([]string, os.Error) - } - var isMacho bool - var err1, err2, err3 os.Error - if f, err1 = elf.Open(obj); err1 != nil { - if f, err2 = pe.Open(obj); err2 != nil { - if f, err3 = macho.Open(obj); err3 != nil { - fatalf("cannot parse %s as ELF (%v) or PE (%v) or Mach-O (%v)", obj, err1, err2, err3) +func dynimport(obj string) { + if f, err := elf.Open(obj); err == nil { + sym, err := f.ImportedSymbols() + if err != nil { + fatalf("cannot load imported symbols from ELF file %s: %v", obj, err) + } + for _, s := range sym { + targ := s.Name + if s.Version != "" { + targ += "@" + s.Version } - isMacho = true + fmt.Printf("#pragma dynimport %s %s %q\n", s.Name, targ, s.Library) + } + lib, err := f.ImportedLibraries() + if err != nil { + fatalf("cannot load imported libraries from ELF file %s: %v", obj, err) + } + for _, l := range lib { + fmt.Printf("#pragma dynimport _ _ %q\n", l) } + return } - var err os.Error - syms, err = f.ImportedSymbols() - if err != nil { - fatalf("cannot load dynamic symbols: %v", err) - } - if isMacho { - // remove leading _ that OS X insists on - for i, s := range syms { - if len(s) >= 2 && s[0] == '_' { - syms[i] = s[1:] + if f, err := macho.Open(obj); err == nil { + sym, err := f.ImportedSymbols() + if err != nil { + fatalf("cannot load imported symbols from Mach-O file %s: %v", obj, err) + } + for _, s := range sym { + if len(s) > 0 && s[0] == '_' { + s = s[1:] } + fmt.Printf("#pragma dynimport %s %s %q\n", s, s, "") + } + lib, err := f.ImportedLibraries() + if err != nil { + fatalf("cannot load imported libraries from Mach-O file %s: %v", obj, err) } + for _, l := range lib { + fmt.Printf("#pragma dynimport _ _ %q\n", l) + } + return } - imports, err = f.ImportedLibraries() - if err != nil { - fatalf("cannot load dynamic imports: %v", err) + if f, err := pe.Open(obj); err == nil { + sym, err := f.ImportedSymbols() + if err != nil { + fatalf("cannot load imported symbols from PE file %s: v", obj, err) + } + for _, s := range sym { + ss := strings.Split(s, ":", -1) + fmt.Printf("#pragma dynimport %s %s %q\n", ss[0], ss[0], strings.ToLower(ss[1])) + } + return } - return + fatalf("cannot parse %s as ELF, Mach-O or PE", obj) } // Construct a gcc struct matching the 6c argument frame. @@ -312,8 +333,11 @@ func (p *Package) writeOutputFunc(fgcc *os.File, n *Name) { } fmt.Fprintf(fgcc, "\t%s *a = v;\n", ctype) fmt.Fprintf(fgcc, "\t") - if n.FuncType.Result != nil { + if t := n.FuncType.Result; t != nil { fmt.Fprintf(fgcc, "a->r = ") + if c := t.C.String(); c[len(c)-1] == '*' { + fmt.Fprintf(fgcc, "(const %s) ", t.C) + } } fmt.Fprintf(fgcc, "%s(", n.C) for i := range n.FuncType.Params { diff --git a/src/cmd/gc/align.c b/src/cmd/gc/align.c index a01e2ea46..a8454bf13 100644 --- a/src/cmd/gc/align.c +++ b/src/cmd/gc/align.c @@ -468,7 +468,7 @@ typeinit(void) okforadd[i] = 1; okforarith[i] = 1; okforconst[i] = 1; -// issimple[i] = 1; + issimple[i] = 1; } } @@ -530,7 +530,7 @@ typeinit(void) okfor[OCOM] = okforand; okfor[OMINUS] = okforarith; okfor[ONOT] = okforbool; - okfor[OPLUS] = okforadd; + okfor[OPLUS] = okforarith; // special okfor[OCAP] = okforcap; diff --git a/src/cmd/gc/cplx.c b/src/cmd/gc/cplx.c index 3ec9fe5a2..890cf7f10 100644 --- a/src/cmd/gc/cplx.c +++ b/src/cmd/gc/cplx.c @@ -12,6 +12,19 @@ static void minus(Node *nl, Node *res); #define CASE(a,b) (((a)<<16)|((b)<<0)) +static int +overlap(Node *f, Node *t) +{ + // check whether f and t could be overlapping stack references. + // not exact, because it's hard to check for the stack register + // in portable code. close enough: worst case we will allocate + // an extra temporary and the registerizer will clean it up. + return f->op == OINDREG && + t->op == OINDREG && + f->xoffset+f->type->width >= t->xoffset && + t->xoffset+t->type->width >= f->xoffset; +} + /* * generate: * res = n; @@ -43,9 +56,10 @@ complexmove(Node *f, Node *t) case CASE(TCOMPLEX64,TCOMPLEX128): case CASE(TCOMPLEX128,TCOMPLEX64): case CASE(TCOMPLEX128,TCOMPLEX128): - // complex to complex move/convert - // make from addable - if(!f->addable) { + // complex to complex move/convert. + // make f addable. + // also use temporary if possible stack overlap. + if(!f->addable || overlap(f, t)) { tempname(&n1, f->type); complexmove(f, &n1); f = &n1; diff --git a/src/cmd/gc/dcl.c b/src/cmd/gc/dcl.c index 3089a23b0..99af18d9f 100644 --- a/src/cmd/gc/dcl.c +++ b/src/cmd/gc/dcl.c @@ -560,6 +560,7 @@ funcargs(Node *nt) { Node *n; NodeList *l; + int gen; if(nt->op != OTFUNC) fatal("funcargs %O", nt->op); @@ -589,6 +590,7 @@ funcargs(Node *nt) } // declare the out arguments. + gen = 0; for(l=nt->rlist; l; l=l->next) { n = l->n; if(n->op != ODCLFIELD) @@ -596,6 +598,11 @@ funcargs(Node *nt) if(n->left != N) { n->left->op = ONAME; n->left->ntype = n->right; + if(isblank(n->left)) { + // Give it a name so we can assign to it during return. + snprint(namebuf, sizeof(namebuf), ".anon%d", gen++); + n->left->sym = lookup(namebuf); + } declare(n->left, PPARAMOUT); } } @@ -672,10 +679,10 @@ typedcl2(Type *pt, Type *t) ok: n = pt->nod; - *pt = *t; - pt->method = nil; + copytype(pt->nod, t); + // unzero nod pt->nod = n; - pt->sym = n->sym; + pt->sym->lastlineno = parserline(); declare(n, PEXTERN); @@ -697,12 +704,10 @@ stotype(NodeList *l, int et, Type **t, int funarg) Type *f, *t1, *t2, **t0; Strlit *note; int lno; - NodeList *init; Node *n, *left; char *what; t0 = t; - init = nil; lno = lineno; what = "field"; if(et == TINTER) @@ -1130,6 +1135,32 @@ addmethod(Sym *sf, Type *t, int local) pa = pa->type; f = methtype(pa); if(f == T) { + t = pa; + if(t != T) { + if(isptr[t->etype]) { + if(t->sym != S) { + yyerror("invalid receiver type %T (%T is a pointer type)", pa, t); + return; + } + t = t->type; + } + } + if(t != T) { + if(t->sym == S) { + yyerror("invalid receiver type %T (%T is an unnamed type)", pa, t); + return; + } + if(isptr[t->etype]) { + yyerror("invalid receiver type %T (%T is a pointer type)", pa, t); + return; + } + if(t->etype == TINTER) { + yyerror("invalid receiver type %T (%T is an interface type)", pa, t); + return; + } + } + // Should have picked off all the reasons above, + // but just in case, fall back to generic error. yyerror("invalid receiver type %T", pa); return; } diff --git a/src/cmd/gc/go.h b/src/cmd/gc/go.h index bb258a193..f58b76789 100644 --- a/src/cmd/gc/go.h +++ b/src/cmd/gc/go.h @@ -315,6 +315,7 @@ struct Pkg { char* name; Strlit* path; + Sym* pathsym; char* prefix; Pkg* link; char exported; // import line written in export data @@ -581,6 +582,7 @@ struct Io Biobuf* bin; int32 ilineno; int nlsemi; + int eofnl; int peekc; int peekc1; // second peekc for ... char* cp; // used for content when bin==nil @@ -1170,9 +1172,12 @@ Node* unsafenmagic(Node *n); */ Node* callnew(Type *t); Node* chanfn(char *name, int n, Type *t); +void copytype(Node *n, Type *t); +void defertypecopy(Node *n, Type *t); Node* mkcall(char *name, Type *t, NodeList **init, ...); Node* mkcall1(Node *fn, Type *t, NodeList **init, ...); void queuemethod(Node *n); +void resumetypecopy(void); int vmatch1(Node *l, Node *r); void walk(Node *fn); Node* walkdef(Node *n); diff --git a/src/cmd/gc/go.y b/src/cmd/gc/go.y index 89899ae1e..7adfd002a 100644 --- a/src/cmd/gc/go.y +++ b/src/cmd/gc/go.y @@ -1853,6 +1853,10 @@ hidden_interfacedcl: { $$ = nod(ODCLFIELD, newname($1), typenod(functype(fakethis(), $3, $5))); } +| hidden_importsym '(' ohidden_funarg_list ')' ohidden_funres + { + $$ = nod(ODCLFIELD, newname($1), typenod(functype(fakethis(), $3, $5))); + } ohidden_funres: { diff --git a/src/cmd/gc/lex.c b/src/cmd/gc/lex.c index bfd96274e..04dd0d5b9 100644 --- a/src/cmd/gc/lex.c +++ b/src/cmd/gc/lex.c @@ -249,6 +249,7 @@ main(int argc, char *argv[]) for(l=xtop; l; l=l->next) if(l->n->op == ODCL || l->n->op == OAS) typecheck(&l->n, Etop); + resumetypecopy(); resumecheckwidth(); for(l=xtop; l; l=l->next) if(l->n->op == ODCLFUNC) @@ -1310,7 +1311,7 @@ getc(void) lexlineno++; return c; } - + if(curio.bin == nil) { c = *curio.cp & 0xff; if(c != 0) @@ -1325,8 +1326,11 @@ getc(void) break; } case EOF: - return EOF; - + // insert \n at EOF + if(curio.eofnl) + return EOF; + curio.eofnl = 1; + c = '\n'; case '\n': if(pushedio.bin == nil) lexlineno++; diff --git a/src/cmd/gc/print.c b/src/cmd/gc/print.c index fee37f6d0..e03a14080 100644 --- a/src/cmd/gc/print.c +++ b/src/cmd/gc/print.c @@ -242,6 +242,17 @@ exprfmt(Fmt *f, Node *n, int prec) exprfmt(f, n->right, 0); break; + case OAS2: + case OAS2DOTTYPE: + case OAS2FUNC: + case OAS2MAPR: + case OAS2MAPW: + case OAS2RECV: + exprlistfmt(f, n->list); + fmtprint(f, " = "); + exprlistfmt(f, n->rlist); + break; + case OADD: case OANDAND: case OANDNOT: diff --git a/src/cmd/gc/reflect.c b/src/cmd/gc/reflect.c index b98e820c6..810787d30 100644 --- a/src/cmd/gc/reflect.c +++ b/src/cmd/gc/reflect.c @@ -137,7 +137,6 @@ methodfunc(Type *f, Type *receiver) static Sig* methods(Type *t) { - int o; Type *f, *mt, *it, *this; Sig *a, *b; Sym *method; @@ -157,7 +156,6 @@ methods(Type *t) // make list of methods for t, // generating code if necessary. a = nil; - o = 0; oldlist = nil; for(f=mt->xmethod; f; f=f->down) { if(f->type->etype != TFUNC) @@ -184,6 +182,11 @@ methods(Type *t) a = b; a->name = method->name; + if(!exportname(method->name)) { + if(method->pkg == nil) + fatal("methods: missing package"); + a->pkg = method->pkg; + } a->isym = methodsym(method, it, 1); a->tsym = methodsym(method, t, 0); a->type = methodfunc(f->type, t); @@ -240,14 +243,12 @@ static Sig* imethods(Type *t) { Sig *a, *all, *last; - int o; Type *f; Sym *method, *isym; Prog *oldlist; all = nil; last = nil; - o = 0; oldlist = nil; for(f=t->type; f; f=f->down) { if(f->etype != TFIELD) @@ -257,8 +258,11 @@ imethods(Type *t) method = f->sym; a = mal(sizeof(*a)); a->name = method->name; - if(!exportname(method->name)) + if(!exportname(method->name)) { + if(method->pkg == nil) + fatal("imethods: missing package"); a->pkg = method->pkg; + } a->mtype = f->type; a->offset = 0; a->type = methodfunc(f->type, nil); @@ -301,26 +305,6 @@ imethods(Type *t) return all; } -static int -dgopkgpath(Sym *s, int ot, Pkg *pkg) -{ - if(pkg == nil) - return dgostringptr(s, ot, nil); - - // Emit reference to go.importpath.""., which 6l will - // rewrite using the correct import path. Every package - // that imports this one directly defines the symbol. - if(pkg == localpkg) { - static Sym *ns; - - if(ns == nil) - ns = pkglookup("importpath.\"\".", mkpkg(strlit("go"))); - return dsymptr(s, ot, ns, 0); - } - - return dgostringptr(s, ot, pkg->name); -} - static void dimportpath(Pkg *p) { @@ -328,6 +312,9 @@ dimportpath(Pkg *p) char *nam; Node *n; + if(p->pathsym != S) + return; + if(gopkg == nil) { gopkg = mkpkg(strlit("go")); gopkg->name = "go"; @@ -339,11 +326,33 @@ dimportpath(Pkg *p) free(nam); n->class = PEXTERN; n->xoffset = 0; + p->pathsym = n->sym; gdatastring(n, p->path); ggloblsym(n->sym, types[TSTRING]->width, 1); } +static int +dgopkgpath(Sym *s, int ot, Pkg *pkg) +{ + if(pkg == nil) + return dgostringptr(s, ot, nil); + + // Emit reference to go.importpath.""., which 6l will + // rewrite using the correct import path. Every package + // that imports this one directly defines the symbol. + if(pkg == localpkg) { + static Sym *ns; + + if(ns == nil) + ns = pkglookup("importpath.\"\".", mkpkg(strlit("go"))); + return dsymptr(s, ot, ns, 0); + } + + dimportpath(pkg); + return dsymptr(s, ot, pkg->pathsym, 0); +} + /* * uncommonType * ../../pkg/runtime/type.go:/uncommonType @@ -694,7 +703,7 @@ dtypesym(Type *t) int ot, xt, n, isddd, dupok; Sym *s, *s1, *s2; Sig *a, *m; - Type *t1, *tbase; + Type *t1, *tbase, *t2; if(isideal(t)) fatal("dtypesym %T", t); @@ -731,15 +740,25 @@ ok: break; case TARRAY: - // ../../pkg/runtime/type.go:/ArrayType - s1 = dtypesym(t->type); - ot = dcommontype(s, ot, t); - xt = ot - 2*widthptr; - ot = dsymptr(s, ot, s1, 0); - if(t->bound < 0) - ot = duintptr(s, ot, -1); - else + if(t->bound >= 0) { + // ../../pkg/runtime/type.go:/ArrayType + s1 = dtypesym(t->type); + t2 = typ(TARRAY); + t2->type = t->type; + t2->bound = -1; // slice + s2 = dtypesym(t2); + ot = dcommontype(s, ot, t); + xt = ot - 2*widthptr; + ot = dsymptr(s, ot, s1, 0); + ot = dsymptr(s, ot, s2, 0); ot = duintptr(s, ot, t->bound); + } else { + // ../../pkg/runtime/type.go:/SliceType + s1 = dtypesym(t->type); + ot = dcommontype(s, ot, t); + xt = ot - 2*widthptr; + ot = dsymptr(s, ot, s1, 0); + } break; case TCHAN: diff --git a/src/cmd/gc/subr.c b/src/cmd/gc/subr.c index 2098794a7..bb2505694 100644 --- a/src/cmd/gc/subr.c +++ b/src/cmd/gc/subr.c @@ -488,7 +488,7 @@ algtype(Type *t) { int a; - if(issimple[t->etype] || isptr[t->etype] || iscomplex[t->etype] || + if(issimple[t->etype] || isptr[t->etype] || t->etype == TCHAN || t->etype == TFUNC || t->etype == TMAP) { if(t->width == widthptr) a = AMEMWORD; @@ -660,12 +660,10 @@ nodbool(int b) Type* aindex(Node *b, Type *t) { - NodeList *init; Type *r; int bound; bound = -1; // open bound - init = nil; typecheck(&b, Erv); if(b != nil) { switch(consttype(b)) { @@ -1266,7 +1264,12 @@ Tpretty(Fmt *fp, Type *t) case TINTER: fmtprint(fp, "interface {"); for(t1=t->type; t1!=T; t1=t1->down) { - fmtprint(fp, " %hS%hhT", t1->sym, t1->type); + fmtprint(fp, " "); + if(exportname(t1->sym->name)) + fmtprint(fp, "%hS", t1->sym); + else + fmtprint(fp, "%S", t1->sym); + fmtprint(fp, "%hhT", t1->type); if(t1->down) fmtprint(fp, ";"); } @@ -1728,17 +1731,13 @@ isideal(Type *t) Type* methtype(Type *t) { - int ptr; - if(t == T) return T; // strip away pointer if it's there - ptr = 0; if(isptr[t->etype]) { if(t->sym != S) return T; - ptr = 1; t = t->type; if(t == T) return T; @@ -1929,13 +1928,14 @@ assignop(Type *src, Type *dst, char **why) } return 0; } + if(isptrto(dst, TINTER)) { + if(why != nil) + *why = smprint(":\n\t%T is pointer to interface, not interface", dst); + return 0; + } if(src->etype == TINTER && dst->etype != TBLANK) { - if(why != nil) { - if(isptrto(dst, TINTER)) - *why = smprint(":\n\t%T is interface, not pointer to interface", src); - else - *why = ": need type assertion"; - } + if(why != nil) + *why = ": need type assertion"; return 0; } diff --git a/src/cmd/gc/swt.c b/src/cmd/gc/swt.c index fbc9c4903..6e8436c3c 100644 --- a/src/cmd/gc/swt.c +++ b/src/cmd/gc/swt.c @@ -250,7 +250,7 @@ newlabel(void) static void casebody(Node *sw, Node *typeswvar) { - Node *os, *oc, *n, *c, *last; + Node *n, *c, *last; Node *def; NodeList *cas, *stat, *l, *lc; Node *go, *br; @@ -263,8 +263,6 @@ casebody(Node *sw, Node *typeswvar) cas = nil; // cases stat = nil; // statements def = N; // defaults - os = N; // last statement - oc = N; // last case br = nod(OBREAK, N, N); for(l=sw->list; l; l=l->next) { diff --git a/src/cmd/gc/typecheck.c b/src/cmd/gc/typecheck.c index 1cc5abd5c..c48bf7a29 100644 --- a/src/cmd/gc/typecheck.c +++ b/src/cmd/gc/typecheck.c @@ -31,6 +31,7 @@ static void checkassign(Node*); static void checkassignlist(NodeList*); static void stringtoarraylit(Node**); static Node* resolve(Node*); +static Type* getforwtype(Node*); /* * resolve ONONAME to definition, if any. @@ -56,7 +57,7 @@ typechecklist(NodeList *l, int top) typecheck(&l->n, top); } -static char* typekind[] = { +static char* _typekind[] = { [TINT] = "int", [TUINT] = "uint", [TINT8] = "int8", @@ -82,8 +83,22 @@ static char* typekind[] = { [TMAP] = "map", [TARRAY] = "array", [TFUNC] = "func", + [TNIL] = "nil", + [TIDEAL] = "ideal number", }; +static char* +typekind(int et) +{ + static char buf[50]; + char *s; + + if(0 <= et && et < nelem(_typekind) && (s=_typekind[et]) != nil) + return s; + snprint(buf, sizeof buf, "etype=%d", et); + return buf; +} + /* * type check node *np. * replaces *np with a new pointer in some cases. @@ -96,7 +111,7 @@ typecheck(Node **np, int top) Node *n, *l, *r; NodeList *args; int lno, ok, ntop; - Type *t, *tp, *missing, *have; + Type *t, *tp, *ft, *missing, *have; Sym *sym; Val v; char *why; @@ -139,6 +154,11 @@ typecheck(Node **np, int top) yyerror("use of builtin %S not in function call", n->sym); goto error; } + + // a dance to handle forward-declared recursive pointer types. + if(n->op == OTYPE && (ft = getforwtype(n->ntype)) != T) + defertypecopy(n, ft); + walkdef(n); n->realtype = n->type; if(n->op == ONONAME) @@ -406,7 +426,7 @@ reswitch: } if(!okfor[op][et]) { notokfor: - yyerror("invalid operation: %#N (operator %#O not defined on %s)", n, op, typekind[et]); + yyerror("invalid operation: %#N (operator %#O not defined on %s)", n, op, typekind(et)); goto error; } // okfor allows any array == array; @@ -992,9 +1012,13 @@ reswitch: defaultlit(&n->right, T); // copy([]byte, string) - if(isslice(n->left->type) && n->left->type->type == types[TUINT8] && n->right->type->etype == TSTRING) - goto ret; - + if(isslice(n->left->type) && n->right->type->etype == TSTRING) { + if (n->left->type->type ==types[TUINT8]) + goto ret; + yyerror("arguments to copy have different element types: %lT and string", n->left->type); + goto error; + } + if(!isslice(n->left->type) || !isslice(n->right->type)) { if(!isslice(n->left->type) && !isslice(n->right->type)) yyerror("arguments to copy must be slices; have %lT, %lT", n->left->type, n->right->type); @@ -2452,3 +2476,24 @@ stringtoarraylit(Node **np) typecheck(&nn, Erv); *np = nn; } + +static Type* +getforwtype(Node *n) +{ + Node *f1, *f2; + + for(f1=f2=n; ; n=n->ntype) { + if((n = resolve(n)) == N || n->op != OTYPE) + return T; + + if(n->type != T && n->type->etype == TFORW) + return n->type; + + // Check for ntype cycle. + if((f2 = resolve(f2)) != N && (f1 = resolve(f2->ntype)) != N) { + f2 = resolve(f1->ntype); + if(f1 == n || f2 == n) + return T; + } + } +} diff --git a/src/cmd/gc/walk.c b/src/cmd/gc/walk.c index b8c6842e0..278eef414 100644 --- a/src/cmd/gc/walk.c +++ b/src/cmd/gc/walk.c @@ -119,6 +119,62 @@ domethod(Node *n) checkwidth(n->type); } +typedef struct NodeTypeList NodeTypeList; +struct NodeTypeList { + Node *n; + Type *t; + NodeTypeList *next; +}; + +static NodeTypeList *dntq; +static NodeTypeList *dntend; + +void +defertypecopy(Node *n, Type *t) +{ + NodeTypeList *ntl; + + if(n == N || t == T) + return; + + ntl = mal(sizeof *ntl); + ntl->n = n; + ntl->t = t; + ntl->next = nil; + + if(dntq == nil) + dntq = ntl; + else + dntend->next = ntl; + + dntend = ntl; +} + +void +resumetypecopy(void) +{ + NodeTypeList *l; + + for(l=dntq; l; l=l->next) + copytype(l->n, l->t); +} + +void +copytype(Node *n, Type *t) +{ + *n->type = *t; + + t = n->type; + t->sym = n->sym; + t->local = n->local; + t->vargen = n->vargen; + t->siggen = 0; + t->method = nil; + t->nod = N; + t->printed = 0; + t->deferwidth = 0; +} + static void walkdeftype(Node *n) { @@ -141,20 +197,14 @@ walkdeftype(Node *n) goto ret; } - // copy new type and clear fields - // that don't come along maplineno = n->type->maplineno; embedlineno = n->type->embedlineno; - *n->type = *t; - t = n->type; - t->sym = n->sym; - t->local = n->local; - t->vargen = n->vargen; - t->siggen = 0; - t->method = nil; - t->nod = N; - t->printed = 0; - t->deferwidth = 0; + + // copy new type and clear fields + // that don't come along. + // anything zeroed here must be zeroed in + // typedcl2 too. + copytype(n, t); // double-check use of type as map key. if(maplineno) { @@ -197,7 +247,6 @@ Node* walkdef(Node *n) { int lno; - NodeList *init; Node *e; Type *t; NodeList *l; @@ -236,7 +285,6 @@ walkdef(Node *n) if(n->type != T || n->sym == S) // builtin or no name goto ret; - init = nil; switch(n->op) { default: fatal("walkdef %O", n->op); @@ -380,14 +428,13 @@ walkstmt(Node **np) { NodeList *init; NodeList *ll, *rl; - int cl, lno; + int cl; Node *n, *f; n = *np; if(n == N) return; - lno = lineno; setlineno(n); switch(n->op) { @@ -1359,7 +1406,7 @@ walkexpr(Node **np, NodeList **init) case OSTRARRAYBYTE: // stringtoslicebyte(string) []byte; - n = mkcall("stringtoslicebyte", n->type, init, n->left); + n = mkcall("stringtoslicebyte", n->type, init, conv(n->left, types[TSTRING])); goto ret; case OSTRARRAYRUNE: @@ -1788,7 +1835,7 @@ walkprint(Node *nn, NodeList **init, int defer) on = syslook("printiface", 1); argtype(on, n->type); // any-1 } - } else if(isptr[et] || et == TCHAN || et == TMAP || et == TFUNC) { + } else if(isptr[et] || et == TCHAN || et == TMAP || et == TFUNC || et == TUNSAFEPTR) { if(defer) { fmtprint(&fmt, "%%p"); } else { diff --git a/src/cmd/gofix/doc.go b/src/cmd/gofix/doc.go index 902fe76f2..a9790e685 100644 --- a/src/cmd/gofix/doc.go +++ b/src/cmd/gofix/doc.go @@ -18,6 +18,9 @@ If the named path is a directory, gofix rewrites all .go files in that directory tree. When gofix rewrites a file, it prints a line to standard error giving the name of the file and the rewrite applied. +If the -diff flag is set, no files are rewritten. Instead gofix prints +the differences a rewrite would introduce. + The -r flag restricts the set of rewrites considered to those in the named list. By default gofix considers all known rewrites. Gofix's rewrites are idempotent, so that it is safe to apply gofix to updated @@ -29,6 +32,5 @@ to see them, run gofix -?. Gofix does not make backup copies of the files that it edits. Instead, use a version control system's ``diff'' functionality to inspect the changes that gofix makes before committing them. - */ package documentation diff --git a/src/cmd/gofix/reflect.go b/src/cmd/gofix/reflect.go index 74ddb398f..3c8becaef 100644 --- a/src/cmd/gofix/reflect.go +++ b/src/cmd/gofix/reflect.go @@ -21,6 +21,7 @@ var reflectFix = fix{ `Adapt code to new reflect API. http://codereview.appspot.com/4281055 +http://codereview.appspot.com/4433066 `, } @@ -279,6 +280,23 @@ func reflectFn(f *ast.File) bool { fixed = true }) + // Rewrite + // reflect.Typeof -> reflect.TypeOf, + walk(f, func(n interface{}) { + sel, ok := n.(*ast.SelectorExpr) + if !ok { + return + } + if isTopName(sel.X, "reflect") && sel.Sel.Name == "Typeof" { + sel.Sel.Name = "TypeOf" + fixed = true + } + if isTopName(sel.X, "reflect") && sel.Sel.Name == "NewValue" { + sel.Sel.Name = "ValueOf" + fixed = true + } + }) + return fixed } diff --git a/src/cmd/gofix/testdata/reflect.asn1.go.out b/src/cmd/gofix/testdata/reflect.asn1.go.out index 902635939..f5716f273 100644 --- a/src/cmd/gofix/testdata/reflect.asn1.go.out +++ b/src/cmd/gofix/testdata/reflect.asn1.go.out @@ -418,13 +418,13 @@ func parseSequenceOf(bytes []byte, sliceType reflect.Type, elemType reflect.Type } var ( - bitStringType = reflect.Typeof(BitString{}) - objectIdentifierType = reflect.Typeof(ObjectIdentifier{}) - enumeratedType = reflect.Typeof(Enumerated(0)) - flagType = reflect.Typeof(Flag(false)) - timeType = reflect.Typeof(&time.Time{}) - rawValueType = reflect.Typeof(RawValue{}) - rawContentsType = reflect.Typeof(RawContent(nil)) + bitStringType = reflect.TypeOf(BitString{}) + objectIdentifierType = reflect.TypeOf(ObjectIdentifier{}) + enumeratedType = reflect.TypeOf(Enumerated(0)) + flagType = reflect.TypeOf(Flag(false)) + timeType = reflect.TypeOf(&time.Time{}) + rawValueType = reflect.TypeOf(RawValue{}) + rawContentsType = reflect.TypeOf(RawContent(nil)) ) // invalidLength returns true iff offset + length > sliceLength, or if the @@ -461,7 +461,7 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam } result := RawValue{t.class, t.tag, t.isCompound, bytes[offset : offset+t.length], bytes[initOffset : offset+t.length]} offset += t.length - v.Set(reflect.NewValue(result)) + v.Set(reflect.ValueOf(result)) return } @@ -506,7 +506,7 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam return } if result != nil { - ifaceValue.Set(reflect.NewValue(result)) + ifaceValue.Set(reflect.ValueOf(result)) } return } @@ -609,7 +609,7 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam sliceValue := v sliceValue.Set(reflect.MakeSlice(sliceValue.Type(), len(newSlice), len(newSlice))) if err1 == nil { - reflect.Copy(sliceValue, reflect.NewValue(newSlice)) + reflect.Copy(sliceValue, reflect.ValueOf(newSlice)) } err = err1 return @@ -617,7 +617,7 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam structValue := v bs, err1 := parseBitString(innerBytes) if err1 == nil { - structValue.Set(reflect.NewValue(bs)) + structValue.Set(reflect.ValueOf(bs)) } err = err1 return @@ -631,7 +631,7 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam time, err1 = parseGeneralizedTime(innerBytes) } if err1 == nil { - ptrValue.Set(reflect.NewValue(time)) + ptrValue.Set(reflect.ValueOf(time)) } err = err1 return @@ -679,7 +679,7 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam if structType.NumField() > 0 && structType.Field(0).Type == rawContentsType { bytes := bytes[initOffset:offset] - val.Field(0).Set(reflect.NewValue(RawContent(bytes))) + val.Field(0).Set(reflect.ValueOf(RawContent(bytes))) } innerOffset := 0 @@ -701,7 +701,7 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam sliceType := fieldType if sliceType.Elem().Kind() == reflect.Uint8 { val.Set(reflect.MakeSlice(sliceType, len(innerBytes), len(innerBytes))) - reflect.Copy(val, reflect.NewValue(innerBytes)) + reflect.Copy(val, reflect.ValueOf(innerBytes)) return } newSlice, err1 := parseSequenceOf(innerBytes, sliceType, sliceType.Elem()) @@ -806,7 +806,7 @@ func Unmarshal(b []byte, val interface{}) (rest []byte, err os.Error) { // UnmarshalWithParams allows field parameters to be specified for the // top-level element. The form of the params is the same as the field tags. func UnmarshalWithParams(b []byte, val interface{}, params string) (rest []byte, err os.Error) { - v := reflect.NewValue(val).Elem() + v := reflect.ValueOf(val).Elem() offset, err := parseField(v, b, 0, parseFieldParameters(params)) if err != nil { return nil, err diff --git a/src/cmd/gofix/testdata/reflect.datafmt.go.out b/src/cmd/gofix/testdata/reflect.datafmt.go.out index 6d816fc2d..bd7f5fd31 100644 --- a/src/cmd/gofix/testdata/reflect.datafmt.go.out +++ b/src/cmd/gofix/testdata/reflect.datafmt.go.out @@ -671,7 +671,7 @@ func (f Format) Eval(env Environment, args ...interface{}) ([]byte, os.Error) { go func() { for _, v := range args { - fld := reflect.NewValue(v) + fld := reflect.ValueOf(v) if !fld.IsValid() { errors <- os.NewError("nil argument") return diff --git a/src/cmd/gofix/testdata/reflect.decode.go.out b/src/cmd/gofix/testdata/reflect.decode.go.out index a5fd33912..feeb7b867 100644 --- a/src/cmd/gofix/testdata/reflect.decode.go.out +++ b/src/cmd/gofix/testdata/reflect.decode.go.out @@ -122,11 +122,11 @@ func (d *decodeState) unmarshal(v interface{}) (err os.Error) { } }() - rv := reflect.NewValue(v) + rv := reflect.ValueOf(v) pv := rv if pv.Kind() != reflect.Ptr || pv.IsNil() { - return &InvalidUnmarshalError{reflect.Typeof(v)} + return &InvalidUnmarshalError{reflect.TypeOf(v)} } d.scan.reset() @@ -314,7 +314,7 @@ func (d *decodeState) array(v reflect.Value) { iv := v ok := iv.Kind() == reflect.Interface if ok { - iv.Set(reflect.NewValue(d.arrayInterface())) + iv.Set(reflect.ValueOf(d.arrayInterface())) return } @@ -410,7 +410,7 @@ func (d *decodeState) object(v reflect.Value) { // Decoding into nil interface? Switch to non-reflect code. iv := v if iv.Kind() == reflect.Interface { - iv.Set(reflect.NewValue(d.objectInterface())) + iv.Set(reflect.ValueOf(d.objectInterface())) return } @@ -423,7 +423,7 @@ func (d *decodeState) object(v reflect.Value) { case reflect.Map: // map must have string type t := v.Type() - if t.Key() != reflect.Typeof("") { + if t.Key() != reflect.TypeOf("") { d.saveError(&UnmarshalTypeError{"object", v.Type()}) break } @@ -514,7 +514,7 @@ func (d *decodeState) object(v reflect.Value) { // Write value back to map; // if using struct, subv points into struct already. if mv.IsValid() { - mv.SetMapIndex(reflect.NewValue(key), subv) + mv.SetMapIndex(reflect.ValueOf(key), subv) } // Next token must be , or }. @@ -570,7 +570,7 @@ func (d *decodeState) literal(v reflect.Value) { case reflect.Bool: v.SetBool(value) case reflect.Interface: - v.Set(reflect.NewValue(value)) + v.Set(reflect.ValueOf(value)) } case '"': // string @@ -592,11 +592,11 @@ func (d *decodeState) literal(v reflect.Value) { d.saveError(err) break } - v.Set(reflect.NewValue(b[0:n])) + v.Set(reflect.ValueOf(b[0:n])) case reflect.String: v.SetString(string(s)) case reflect.Interface: - v.Set(reflect.NewValue(string(s))) + v.Set(reflect.ValueOf(string(s))) } default: // number @@ -613,7 +613,7 @@ func (d *decodeState) literal(v reflect.Value) { d.saveError(&UnmarshalTypeError{"number " + s, v.Type()}) break } - v.Set(reflect.NewValue(n)) + v.Set(reflect.ValueOf(n)) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: n, err := strconv.Atoi64(s) @@ -767,7 +767,7 @@ func (d *decodeState) literalInterface() interface{} { } n, err := strconv.Atof64(string(item)) if err != nil { - d.saveError(&UnmarshalTypeError{"number " + string(item), reflect.Typeof(0.0)}) + d.saveError(&UnmarshalTypeError{"number " + string(item), reflect.TypeOf(0.0)}) } return n } diff --git a/src/cmd/gofix/testdata/reflect.decoder.go.out b/src/cmd/gofix/testdata/reflect.decoder.go.out index a631c27a2..170eedb05 100644 --- a/src/cmd/gofix/testdata/reflect.decoder.go.out +++ b/src/cmd/gofix/testdata/reflect.decoder.go.out @@ -50,7 +50,7 @@ func (dec *Decoder) recvType(id typeId) { // Type: wire := new(wireType) - dec.decodeValue(tWireType, reflect.NewValue(wire)) + dec.decodeValue(tWireType, reflect.ValueOf(wire)) if dec.err != nil { return } @@ -161,7 +161,7 @@ func (dec *Decoder) Decode(e interface{}) os.Error { if e == nil { return dec.DecodeValue(reflect.Value{}) } - value := reflect.NewValue(e) + value := reflect.ValueOf(e) // If e represents a value as opposed to a pointer, the answer won't // get back to the caller. Make sure it's a pointer. if value.Type().Kind() != reflect.Ptr { diff --git a/src/cmd/gofix/testdata/reflect.dnsmsg.go.out b/src/cmd/gofix/testdata/reflect.dnsmsg.go.out index 546e713a0..12e4c34c3 100644 --- a/src/cmd/gofix/testdata/reflect.dnsmsg.go.out +++ b/src/cmd/gofix/testdata/reflect.dnsmsg.go.out @@ -430,7 +430,7 @@ func packStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok bool) if off+n > len(msg) { return len(msg), false } - reflect.Copy(reflect.NewValue(msg[off:off+n]), fv) + reflect.Copy(reflect.ValueOf(msg[off:off+n]), fv) off += n case reflect.String: // There are multiple string encodings. @@ -460,7 +460,7 @@ func packStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok bool) } func structValue(any interface{}) reflect.Value { - return reflect.NewValue(any).Elem() + return reflect.ValueOf(any).Elem() } func packStruct(any interface{}, msg []byte, off int) (off1 int, ok bool) { @@ -508,7 +508,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo if off+n > len(msg) { return len(msg), false } - reflect.Copy(fv, reflect.NewValue(msg[off:off+n])) + reflect.Copy(fv, reflect.ValueOf(msg[off:off+n])) off += n case reflect.String: var s string diff --git a/src/cmd/gofix/testdata/reflect.encode.go.out b/src/cmd/gofix/testdata/reflect.encode.go.out index 8c79a27d4..9a13a75ab 100644 --- a/src/cmd/gofix/testdata/reflect.encode.go.out +++ b/src/cmd/gofix/testdata/reflect.encode.go.out @@ -172,7 +172,7 @@ func (e *encodeState) marshal(v interface{}) (err os.Error) { err = r.(os.Error) } }() - e.reflectValue(reflect.NewValue(v)) + e.reflectValue(reflect.ValueOf(v)) return nil } @@ -180,7 +180,7 @@ func (e *encodeState) error(err os.Error) { panic(err) } -var byteSliceType = reflect.Typeof([]byte(nil)) +var byteSliceType = reflect.TypeOf([]byte(nil)) func (e *encodeState) reflectValue(v reflect.Value) { if !v.IsValid() { diff --git a/src/cmd/gofix/testdata/reflect.encoder.go.out b/src/cmd/gofix/testdata/reflect.encoder.go.out index 928f3b244..781ef6504 100644 --- a/src/cmd/gofix/testdata/reflect.encoder.go.out +++ b/src/cmd/gofix/testdata/reflect.encoder.go.out @@ -97,7 +97,7 @@ func (enc *Encoder) sendActualType(w io.Writer, state *encoderState, ut *userTyp // Id: state.encodeInt(-int64(info.id)) // Type: - enc.encode(state.b, reflect.NewValue(info.wire), wireTypeUserInfo) + enc.encode(state.b, reflect.ValueOf(info.wire), wireTypeUserInfo) enc.writeMessage(w, state.b) if enc.err != nil { return @@ -162,7 +162,7 @@ func (enc *Encoder) sendType(w io.Writer, state *encoderState, origt reflect.Typ // Encode transmits the data item represented by the empty interface value, // guaranteeing that all necessary type information has been transmitted first. func (enc *Encoder) Encode(e interface{}) os.Error { - return enc.EncodeValue(reflect.NewValue(e)) + return enc.EncodeValue(reflect.ValueOf(e)) } // sendTypeDescriptor makes sure the remote side knows about this type. diff --git a/src/cmd/gofix/testdata/reflect.export.go.out b/src/cmd/gofix/testdata/reflect.export.go.out index 2209f04e8..486a812e2 100644 --- a/src/cmd/gofix/testdata/reflect.export.go.out +++ b/src/cmd/gofix/testdata/reflect.export.go.out @@ -111,9 +111,9 @@ func (client *expClient) getChan(hdr *header, dir Dir) *netChan { // data arrives from the client. func (client *expClient) run() { hdr := new(header) - hdrValue := reflect.NewValue(hdr) + hdrValue := reflect.ValueOf(hdr) req := new(request) - reqValue := reflect.NewValue(req) + reqValue := reflect.ValueOf(req) error := new(error) for { *hdr = header{} @@ -341,7 +341,7 @@ func (exp *Exporter) Sync(timeout int64) os.Error { } func checkChan(chT interface{}, dir Dir) (reflect.Value, os.Error) { - chanType := reflect.Typeof(chT) + chanType := reflect.TypeOf(chT) if chanType.Kind() != reflect.Chan { return reflect.Value{}, os.ErrorString("not a channel") } @@ -359,7 +359,7 @@ func checkChan(chT interface{}, dir Dir) (reflect.Value, os.Error) { return reflect.Value{}, os.ErrorString("to import/export with Recv, must provide chan<-") } } - return reflect.NewValue(chT), nil + return reflect.ValueOf(chT), nil } // Export exports a channel of a given type and specified direction. The diff --git a/src/cmd/gofix/testdata/reflect.print.go.out b/src/cmd/gofix/testdata/reflect.print.go.out index e3dc775cf..079948cca 100644 --- a/src/cmd/gofix/testdata/reflect.print.go.out +++ b/src/cmd/gofix/testdata/reflect.print.go.out @@ -260,7 +260,7 @@ func getField(v reflect.Value, i int) reflect.Value { val := v.Field(i) if i := val; i.Kind() == reflect.Interface { if inter := i.Interface(); inter != nil { - return reflect.NewValue(inter) + return reflect.ValueOf(inter) } } return val @@ -284,7 +284,7 @@ func (p *pp) unknownType(v interface{}) { return } p.buf.WriteByte('?') - p.buf.WriteString(reflect.Typeof(v).String()) + p.buf.WriteString(reflect.TypeOf(v).String()) p.buf.WriteByte('?') } @@ -296,7 +296,7 @@ func (p *pp) badVerb(verb int, val interface{}) { if val == nil { p.buf.Write(nilAngleBytes) } else { - p.buf.WriteString(reflect.Typeof(val).String()) + p.buf.WriteString(reflect.TypeOf(val).String()) p.add('=') p.printField(val, 'v', false, false, 0) } @@ -525,7 +525,7 @@ func (p *pp) fmtPointer(field interface{}, value reflect.Value, verb int, goSynt } if goSyntax { p.add('(') - p.buf.WriteString(reflect.Typeof(field).String()) + p.buf.WriteString(reflect.TypeOf(field).String()) p.add(')') p.add('(') if u == 0 { @@ -540,10 +540,10 @@ func (p *pp) fmtPointer(field interface{}, value reflect.Value, verb int, goSynt } var ( - intBits = reflect.Typeof(0).Bits() - floatBits = reflect.Typeof(0.0).Bits() - complexBits = reflect.Typeof(1i).Bits() - uintptrBits = reflect.Typeof(uintptr(0)).Bits() + intBits = reflect.TypeOf(0).Bits() + floatBits = reflect.TypeOf(0.0).Bits() + complexBits = reflect.TypeOf(1i).Bits() + uintptrBits = reflect.TypeOf(uintptr(0)).Bits() ) func (p *pp) printField(field interface{}, verb int, plus, goSyntax bool, depth int) (wasString bool) { @@ -560,10 +560,10 @@ func (p *pp) printField(field interface{}, verb int, plus, goSyntax bool, depth // %T (the value's type) and %p (its address) are special; we always do them first. switch verb { case 'T': - p.printField(reflect.Typeof(field).String(), 's', false, false, 0) + p.printField(reflect.TypeOf(field).String(), 's', false, false, 0) return false case 'p': - p.fmtPointer(field, reflect.NewValue(field), verb, goSyntax) + p.fmtPointer(field, reflect.ValueOf(field), verb, goSyntax) return false } // Is it a Formatter? @@ -651,7 +651,7 @@ func (p *pp) printField(field interface{}, verb int, plus, goSyntax bool, depth } // Need to use reflection - value := reflect.NewValue(field) + value := reflect.ValueOf(field) BigSwitch: switch f := value; f.Kind() { @@ -702,7 +702,7 @@ BigSwitch: } case reflect.Struct: if goSyntax { - p.buf.WriteString(reflect.Typeof(field).String()) + p.buf.WriteString(reflect.TypeOf(field).String()) } p.add('{') v := f @@ -728,7 +728,7 @@ BigSwitch: value := f.Elem() if !value.IsValid() { if goSyntax { - p.buf.WriteString(reflect.Typeof(field).String()) + p.buf.WriteString(reflect.TypeOf(field).String()) p.buf.Write(nilParenBytes) } else { p.buf.Write(nilAngleBytes) @@ -754,7 +754,7 @@ BigSwitch: return verb == 's' } if goSyntax { - p.buf.WriteString(reflect.Typeof(field).String()) + p.buf.WriteString(reflect.TypeOf(field).String()) p.buf.WriteByte('{') } else { p.buf.WriteByte('[') @@ -792,7 +792,7 @@ BigSwitch: } if goSyntax { p.buf.WriteByte('(') - p.buf.WriteString(reflect.Typeof(field).String()) + p.buf.WriteString(reflect.TypeOf(field).String()) p.buf.WriteByte(')') p.buf.WriteByte('(') if v == 0 { @@ -913,7 +913,7 @@ func (p *pp) doPrintf(format string, a []interface{}) { for ; fieldnum < len(a); fieldnum++ { field := a[fieldnum] if field != nil { - p.buf.WriteString(reflect.Typeof(field).String()) + p.buf.WriteString(reflect.TypeOf(field).String()) p.buf.WriteByte('=') } p.printField(field, 'v', false, false, 0) @@ -932,7 +932,7 @@ func (p *pp) doPrint(a []interface{}, addspace, addnewline bool) { // always add spaces if we're doing println field := a[fieldnum] if fieldnum > 0 { - isString := field != nil && reflect.Typeof(field).Kind() == reflect.String + isString := field != nil && reflect.TypeOf(field).Kind() == reflect.String if addspace || !isString && !prevString { p.buf.WriteByte(' ') } diff --git a/src/cmd/gofix/testdata/reflect.quick.go.out b/src/cmd/gofix/testdata/reflect.quick.go.out index 152dbad32..c62305b83 100644 --- a/src/cmd/gofix/testdata/reflect.quick.go.out +++ b/src/cmd/gofix/testdata/reflect.quick.go.out @@ -59,39 +59,39 @@ func Value(t reflect.Type, rand *rand.Rand) (value reflect.Value, ok bool) { switch concrete := t; concrete.Kind() { case reflect.Bool: - return reflect.NewValue(rand.Int()&1 == 0), true + return reflect.ValueOf(rand.Int()&1 == 0), true case reflect.Float32, reflect.Float64, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr, reflect.Complex64, reflect.Complex128: switch t.Kind() { case reflect.Float32: - return reflect.NewValue(randFloat32(rand)), true + return reflect.ValueOf(randFloat32(rand)), true case reflect.Float64: - return reflect.NewValue(randFloat64(rand)), true + return reflect.ValueOf(randFloat64(rand)), true case reflect.Complex64: - return reflect.NewValue(complex(randFloat32(rand), randFloat32(rand))), true + return reflect.ValueOf(complex(randFloat32(rand), randFloat32(rand))), true case reflect.Complex128: - return reflect.NewValue(complex(randFloat64(rand), randFloat64(rand))), true + return reflect.ValueOf(complex(randFloat64(rand), randFloat64(rand))), true case reflect.Int16: - return reflect.NewValue(int16(randInt64(rand))), true + return reflect.ValueOf(int16(randInt64(rand))), true case reflect.Int32: - return reflect.NewValue(int32(randInt64(rand))), true + return reflect.ValueOf(int32(randInt64(rand))), true case reflect.Int64: - return reflect.NewValue(randInt64(rand)), true + return reflect.ValueOf(randInt64(rand)), true case reflect.Int8: - return reflect.NewValue(int8(randInt64(rand))), true + return reflect.ValueOf(int8(randInt64(rand))), true case reflect.Int: - return reflect.NewValue(int(randInt64(rand))), true + return reflect.ValueOf(int(randInt64(rand))), true case reflect.Uint16: - return reflect.NewValue(uint16(randInt64(rand))), true + return reflect.ValueOf(uint16(randInt64(rand))), true case reflect.Uint32: - return reflect.NewValue(uint32(randInt64(rand))), true + return reflect.ValueOf(uint32(randInt64(rand))), true case reflect.Uint64: - return reflect.NewValue(uint64(randInt64(rand))), true + return reflect.ValueOf(uint64(randInt64(rand))), true case reflect.Uint8: - return reflect.NewValue(uint8(randInt64(rand))), true + return reflect.ValueOf(uint8(randInt64(rand))), true case reflect.Uint: - return reflect.NewValue(uint(randInt64(rand))), true + return reflect.ValueOf(uint(randInt64(rand))), true case reflect.Uintptr: - return reflect.NewValue(uintptr(randInt64(rand))), true + return reflect.ValueOf(uintptr(randInt64(rand))), true } case reflect.Map: numElems := rand.Intn(complexSize) @@ -130,7 +130,7 @@ func Value(t reflect.Type, rand *rand.Rand) (value reflect.Value, ok bool) { for i := 0; i < numChars; i++ { codePoints[i] = rand.Intn(0x10ffff) } - return reflect.NewValue(string(codePoints)), true + return reflect.ValueOf(string(codePoints)), true case reflect.Struct: s := reflect.Zero(t) for i := 0; i < s.NumField(); i++ { @@ -339,7 +339,7 @@ func arbitraryValues(args []reflect.Value, f reflect.Type, config *Config, rand } func functionAndType(f interface{}) (v reflect.Value, t reflect.Type, ok bool) { - v = reflect.NewValue(f) + v = reflect.ValueOf(f) ok = v.Kind() == reflect.Func if !ok { return diff --git a/src/cmd/gofix/testdata/reflect.read.go.out b/src/cmd/gofix/testdata/reflect.read.go.out index a3ddb9d4c..554b2a61b 100644 --- a/src/cmd/gofix/testdata/reflect.read.go.out +++ b/src/cmd/gofix/testdata/reflect.read.go.out @@ -139,7 +139,7 @@ import ( // to a freshly allocated value and then mapping the element to that value. // func Unmarshal(r io.Reader, val interface{}) os.Error { - v := reflect.NewValue(val) + v := reflect.ValueOf(val) if v.Kind() != reflect.Ptr { return os.NewError("non-pointer passed to Unmarshal") } @@ -176,7 +176,7 @@ func (e *TagPathError) String() string { // Passing a nil start element indicates that Unmarshal should // read the token stream to find the start element. func (p *Parser) Unmarshal(val interface{}, start *StartElement) os.Error { - v := reflect.NewValue(val) + v := reflect.ValueOf(val) if v.Kind() != reflect.Ptr { return os.NewError("non-pointer passed to Unmarshal") } @@ -280,7 +280,7 @@ func (p *Parser) unmarshal(val reflect.Value, start *StartElement) os.Error { case reflect.Struct: if _, ok := v.Interface().(Name); ok { - v.Set(reflect.NewValue(start.Name)) + v.Set(reflect.ValueOf(start.Name)) break } @@ -316,7 +316,7 @@ func (p *Parser) unmarshal(val reflect.Value, start *StartElement) os.Error { if _, ok := v.Interface().(Name); !ok { return UnmarshalError(sv.Type().String() + " field XMLName does not have type xml.Name") } - v.Set(reflect.NewValue(start.Name)) + v.Set(reflect.ValueOf(start.Name)) } // Assign attributes. @@ -508,21 +508,21 @@ Loop: case reflect.String: t.SetString(string(data)) case reflect.Slice: - t.Set(reflect.NewValue(data)) + t.Set(reflect.ValueOf(data)) } switch t := saveComment; t.Kind() { case reflect.String: t.SetString(string(comment)) case reflect.Slice: - t.Set(reflect.NewValue(comment)) + t.Set(reflect.ValueOf(comment)) } switch t := saveXML; t.Kind() { case reflect.String: t.SetString(string(saveXMLData)) case reflect.Slice: - t.Set(reflect.NewValue(saveXMLData)) + t.Set(reflect.ValueOf(saveXMLData)) } return nil diff --git a/src/cmd/gofix/testdata/reflect.scan.go.out b/src/cmd/gofix/testdata/reflect.scan.go.out index b1b3975e2..42bc52c92 100644 --- a/src/cmd/gofix/testdata/reflect.scan.go.out +++ b/src/cmd/gofix/testdata/reflect.scan.go.out @@ -423,7 +423,7 @@ func (s *ss) token(skipSpace bool, f func(int) bool) []byte { // typeError indicates that the type of the operand did not match the format func (s *ss) typeError(field interface{}, expected string) { - s.errorString("expected field of type pointer to " + expected + "; found " + reflect.Typeof(field).String()) + s.errorString("expected field of type pointer to " + expected + "; found " + reflect.TypeOf(field).String()) } var complexError = os.ErrorString("syntax error scanning complex number") @@ -908,7 +908,7 @@ func (s *ss) scanOne(verb int, field interface{}) { // If we scanned to bytes, the slice would point at the buffer. *v = []byte(s.convertString(verb)) default: - val := reflect.NewValue(v) + val := reflect.ValueOf(v) ptr := val if ptr.Kind() != reflect.Ptr { s.errorString("Scan: type not a pointer: " + val.Type().String()) diff --git a/src/cmd/gofix/testdata/reflect.script.go.out b/src/cmd/gofix/testdata/reflect.script.go.out index b18018497..bc5a6a41d 100644 --- a/src/cmd/gofix/testdata/reflect.script.go.out +++ b/src/cmd/gofix/testdata/reflect.script.go.out @@ -134,19 +134,19 @@ type empty struct { } func newEmptyInterface(e empty) reflect.Value { - return reflect.NewValue(e).Field(0) + return reflect.ValueOf(e).Field(0) } func (s Send) send() { // With reflect.ChanValue.Send, we must match the types exactly. So, if // s.Channel is a chan interface{} we convert s.Value to an interface{} // first. - c := reflect.NewValue(s.Channel) + c := reflect.ValueOf(s.Channel) var v reflect.Value if iface := c.Type().Elem(); iface.Kind() == reflect.Interface && iface.NumMethod() == 0 { v = newEmptyInterface(empty{s.Value}) } else { - v = reflect.NewValue(s.Value) + v = reflect.ValueOf(s.Value) } c.Send(v) } @@ -162,7 +162,7 @@ func (s Close) getSend() sendAction { return s } func (s Close) getChannel() interface{} { return s.Channel } -func (s Close) send() { reflect.NewValue(s.Channel).Close() } +func (s Close) send() { reflect.ValueOf(s.Channel).Close() } // A ReceivedUnexpected error results if no active Events match a value // received from a channel. @@ -278,7 +278,7 @@ func getChannels(events []*Event) ([]interface{}, os.Error) { continue } c := event.action.getChannel() - if reflect.NewValue(c).Kind() != reflect.Chan { + if reflect.ValueOf(c).Kind() != reflect.Chan { return nil, SetupError("one of the channel values is not a channel") } @@ -303,7 +303,7 @@ func getChannels(events []*Event) ([]interface{}, os.Error) { // channel repeatedly, wrapping them up as either a channelRecv or // channelClosed structure, and forwards them to the multiplex channel. func recvValues(multiplex chan<- interface{}, channel interface{}) { - c := reflect.NewValue(channel) + c := reflect.ValueOf(channel) for { v, ok := c.Recv() diff --git a/src/cmd/gofix/testdata/reflect.template.go.out b/src/cmd/gofix/testdata/reflect.template.go.out index 28872dbee..c36288455 100644 --- a/src/cmd/gofix/testdata/reflect.template.go.out +++ b/src/cmd/gofix/testdata/reflect.template.go.out @@ -646,7 +646,7 @@ func (t *Template) lookup(st *state, v reflect.Value, name string) reflect.Value } return av.FieldByName(name) case reflect.Map: - if v := av.MapIndex(reflect.NewValue(name)); v.IsValid() { + if v := av.MapIndex(reflect.ValueOf(name)); v.IsValid() { return v } return reflect.Zero(typ.Elem()) @@ -797,7 +797,7 @@ func (t *Template) executeElement(i int, st *state) int { return elem.end } e := t.elems.At(i) - t.execError(st, 0, "internal error: bad directive in execute: %v %T\n", reflect.NewValue(e).Interface(), e) + t.execError(st, 0, "internal error: bad directive in execute: %v %T\n", reflect.ValueOf(e).Interface(), e) return 0 } @@ -980,7 +980,7 @@ func (t *Template) ParseFile(filename string) (err os.Error) { // generating output to wr. func (t *Template) Execute(wr io.Writer, data interface{}) (err os.Error) { // Extract the driver data. - val := reflect.NewValue(data) + val := reflect.ValueOf(data) defer checkError(&err) t.p = 0 t.execute(0, t.elems.Len(), &state{parent: nil, data: val, wr: wr}) diff --git a/src/cmd/gofix/testdata/reflect.type.go.out b/src/cmd/gofix/testdata/reflect.type.go.out index 8fd174841..a39b074fe 100644 --- a/src/cmd/gofix/testdata/reflect.type.go.out +++ b/src/cmd/gofix/testdata/reflect.type.go.out @@ -243,18 +243,18 @@ var ( ) // Predefined because it's needed by the Decoder -var tWireType = mustGetTypeInfo(reflect.Typeof(wireType{})).id +var tWireType = mustGetTypeInfo(reflect.TypeOf(wireType{})).id var wireTypeUserInfo *userTypeInfo // userTypeInfo of (*wireType) func init() { // Some magic numbers to make sure there are no surprises. checkId(16, tWireType) - checkId(17, mustGetTypeInfo(reflect.Typeof(arrayType{})).id) - checkId(18, mustGetTypeInfo(reflect.Typeof(CommonType{})).id) - checkId(19, mustGetTypeInfo(reflect.Typeof(sliceType{})).id) - checkId(20, mustGetTypeInfo(reflect.Typeof(structType{})).id) - checkId(21, mustGetTypeInfo(reflect.Typeof(fieldType{})).id) - checkId(23, mustGetTypeInfo(reflect.Typeof(mapType{})).id) + checkId(17, mustGetTypeInfo(reflect.TypeOf(arrayType{})).id) + checkId(18, mustGetTypeInfo(reflect.TypeOf(CommonType{})).id) + checkId(19, mustGetTypeInfo(reflect.TypeOf(sliceType{})).id) + checkId(20, mustGetTypeInfo(reflect.TypeOf(structType{})).id) + checkId(21, mustGetTypeInfo(reflect.TypeOf(fieldType{})).id) + checkId(23, mustGetTypeInfo(reflect.TypeOf(mapType{})).id) builtinIdToType = make(map[typeId]gobType) for k, v := range idToType { @@ -268,7 +268,7 @@ func init() { } nextId = firstUserId registerBasics() - wireTypeUserInfo = userType(reflect.Typeof((*wireType)(nil))) + wireTypeUserInfo = userType(reflect.TypeOf((*wireType)(nil))) } // Array type @@ -569,7 +569,7 @@ func checkId(want, got typeId) { // used for building the basic types; called only from init(). the incoming // interface always refers to a pointer. func bootstrapType(name string, e interface{}, expect typeId) typeId { - rt := reflect.Typeof(e).Elem() + rt := reflect.TypeOf(e).Elem() _, present := types[rt] if present { panic("bootstrap type already present: " + name + ", " + rt.String()) @@ -723,7 +723,7 @@ func RegisterName(name string, value interface{}) { // reserved for nil panic("attempt to register empty name") } - base := userType(reflect.Typeof(value)).base + base := userType(reflect.TypeOf(value)).base // Check for incompatible duplicates. if t, ok := nameToConcreteType[name]; ok && t != base { panic("gob: registering duplicate types for " + name) @@ -732,7 +732,7 @@ func RegisterName(name string, value interface{}) { panic("gob: registering duplicate names for " + base.String()) } // Store the name and type provided by the user.... - nameToConcreteType[name] = reflect.Typeof(value) + nameToConcreteType[name] = reflect.TypeOf(value) // but the flattened type in the type table, since that's what decode needs. concreteTypeToName[base] = name } @@ -745,7 +745,7 @@ func RegisterName(name string, value interface{}) { // between types and names is not a bijection. func Register(value interface{}) { // Default to printed representation for unnamed types - rt := reflect.Typeof(value) + rt := reflect.TypeOf(value) name := rt.String() // But for named types (or pointers to them), qualify with import path. diff --git a/src/cmd/gofix/typecheck.go b/src/cmd/gofix/typecheck.go index d565e7b4b..2d81b9710 100644 --- a/src/cmd/gofix/typecheck.go +++ b/src/cmd/gofix/typecheck.go @@ -259,7 +259,7 @@ func typecheck1(cfg *TypeConfig, f interface{}, typeof map[interface{}]string) { if n == nil { return } - if false && reflect.Typeof(n).Kind() == reflect.Ptr { // debugging trace + if false && reflect.TypeOf(n).Kind() == reflect.Ptr { // debugging trace defer func() { if t := typeof[n]; t != "" { pos := fset.Position(n.(ast.Node).Pos()) @@ -375,6 +375,11 @@ func typecheck1(cfg *TypeConfig, f interface{}, typeof map[interface{}]string) { typeof[n] = gofmt(n.Args[0]) return } + // new(T) has type *T + if isTopName(n.Fun, "new") && len(n.Args) == 1 { + typeof[n] = "*" + gofmt(n.Args[0]) + return + } // Otherwise, use type of function to determine arguments. t := typeof[n.Fun] in, out := splitFunc(t) diff --git a/src/cmd/gofmt/doc.go b/src/cmd/gofmt/doc.go index e44030eee..1373b2657 100644 --- a/src/cmd/gofmt/doc.go +++ b/src/cmd/gofmt/doc.go @@ -8,29 +8,37 @@ Gofmt formats Go programs. Without an explicit path, it processes the standard input. Given a file, it operates on that file; given a directory, it operates on all .go files in that directory, recursively. (Files starting with a period are ignored.) +By default, gofmt prints the reformatted sources to standard output. Usage: gofmt [flags] [path ...] The flags are: + -d + Do not print reformatted sources to standard output. + If a file's formatting is different than gofmt's, print diffs + to standard output. -l - just list files whose formatting differs from gofmt's; - generate no other output unless -w is also set. + Do not print reformatted sources to standard output. + If a file's formatting is different from gofmt's, print its name + to standard output. -r rule - apply the rewrite rule to the source before reformatting. + Apply the rewrite rule to the source before reformatting. -s - try to simplify code (after applying the rewrite rule, if any). + Try to simplify code (after applying the rewrite rule, if any). -w - if set, overwrite each input file with its output. + Do not print reformatted sources to standard output. + If a file's formatting is different from gofmt's, overwrite it + with gofmt's version. -comments=true - print comments; if false, all comments are elided from the output. + Print comments; if false, all comments are elided from the output. -spaces - align with spaces instead of tabs. + Align with spaces instead of tabs. -tabindent - indent with tabs independent of -spaces. + Indent with tabs independent of -spaces. -tabwidth=8 - tab width in spaces. + Tab width in spaces. The rewrite rule specified with the -r flag must be a string of the form: diff --git a/src/cmd/gofmt/gofmt.go b/src/cmd/gofmt/gofmt.go index ce274aa21..5dd801d90 100644 --- a/src/cmd/gofmt/gofmt.go +++ b/src/cmd/gofmt/gofmt.go @@ -6,6 +6,7 @@ package main import ( "bytes" + "exec" "flag" "fmt" "go/ast" @@ -28,6 +29,7 @@ var ( write = flag.Bool("w", false, "write result to (source) file instead of stdout") rewriteRule = flag.String("r", "", "rewrite rule (e.g., 'α[β:len(α)] -> α[β:]')") simplifyAST = flag.Bool("s", false, "simplify code") + doDiff = flag.Bool("d", false, "display diffs instead of rewriting files") // layout control comments = flag.Bool("comments", true, "print comments") @@ -134,9 +136,17 @@ func processFile(filename string, in io.Reader, out io.Writer) os.Error { return err } } + if *doDiff { + data, err := diff(src, res) + if err != nil { + return fmt.Errorf("computing diff: %s", err) + } + fmt.Printf("diff %s gofmt/%s\n", filename, filename) + out.Write(data) + } } - if !*list && !*write { + if !*list && !*write && !*doDiff { _, err = out.Write(res) } @@ -230,3 +240,37 @@ func gofmtMain() { } } } + + +func diff(b1, b2 []byte) (data []byte, err os.Error) { + f1, err := ioutil.TempFile("", "gofmt") + if err != nil { + return nil, err + } + defer os.Remove(f1.Name()) + defer f1.Close() + + f2, err := ioutil.TempFile("", "gofmt") + if err != nil { + return nil, err + } + defer os.Remove(f2.Name()) + defer f2.Close() + + f1.Write(b1) + f2.Write(b2) + + diffcmd, err := exec.LookPath("diff") + if err != nil { + return nil, err + } + + c, err := exec.Run(diffcmd, []string{"diff", "-u", f1.Name(), f2.Name()}, + nil, "", exec.DevNull, exec.Pipe, exec.MergeWithStdout) + if err != nil { + return nil, err + } + defer c.Close() + + return ioutil.ReadAll(c.Stdout) +} diff --git a/src/cmd/gofmt/gofmt_test.go b/src/cmd/gofmt/gofmt_test.go index 4ec94e293..a72530307 100644 --- a/src/cmd/gofmt/gofmt_test.go +++ b/src/cmd/gofmt/gofmt_test.go @@ -71,6 +71,7 @@ var tests = []struct { {".", "gofmt_test.go", "gofmt_test.go", ""}, {"testdata", "composites.input", "composites.golden", "-s"}, {"testdata", "rewrite1.input", "rewrite1.golden", "-r=Foo->Bar"}, + {"testdata", "rewrite2.input", "rewrite2.golden", "-r=int->bool"}, } diff --git a/src/cmd/gofmt/rewrite.go b/src/cmd/gofmt/rewrite.go index 93643dced..4c24282f3 100644 --- a/src/cmd/gofmt/rewrite.go +++ b/src/cmd/gofmt/rewrite.go @@ -19,6 +19,7 @@ import ( func initRewrite() { if *rewriteRule == "" { + rewrite = nil // disable any previous rewrite return } f := strings.Split(*rewriteRule, "->", -1) @@ -59,26 +60,34 @@ func dump(msg string, val reflect.Value) { // rewriteFile applies the rewrite rule 'pattern -> replace' to an entire file. func rewriteFile(pattern, replace ast.Expr, p *ast.File) *ast.File { m := make(map[string]reflect.Value) - pat := reflect.NewValue(pattern) - repl := reflect.NewValue(replace) + pat := reflect.ValueOf(pattern) + repl := reflect.ValueOf(replace) var f func(val reflect.Value) reflect.Value // f is recursive f = func(val reflect.Value) reflect.Value { + // don't bother if val is invalid to start with + if !val.IsValid() { + return reflect.Value{} + } for k := range m { m[k] = reflect.Value{}, false } val = apply(f, val) if match(m, pat, val) { - val = subst(m, repl, reflect.NewValue(val.Interface().(ast.Node).Pos())) + val = subst(m, repl, reflect.ValueOf(val.Interface().(ast.Node).Pos())) } return val } - return apply(f, reflect.NewValue(p)).Interface().(*ast.File) + return apply(f, reflect.ValueOf(p)).Interface().(*ast.File) } // setValue is a wrapper for x.SetValue(y); it protects // the caller from panics if x cannot be changed to y. func setValue(x, y reflect.Value) { + // don't bother if y is invalid to start with + if !y.IsValid() { + return + } defer func() { if x := recover(); x != nil { if s, ok := x.(string); ok && strings.HasPrefix(s, "type mismatch") { @@ -94,11 +103,13 @@ func setValue(x, y reflect.Value) { // Values/types for special cases. var ( - objectPtrNil = reflect.NewValue((*ast.Object)(nil)) + objectPtrNil = reflect.ValueOf((*ast.Object)(nil)) + scopePtrNil = reflect.ValueOf((*ast.Scope)(nil)) - identType = reflect.Typeof((*ast.Ident)(nil)) - objectPtrType = reflect.Typeof((*ast.Object)(nil)) - positionType = reflect.Typeof(token.NoPos) + identType = reflect.TypeOf((*ast.Ident)(nil)) + objectPtrType = reflect.TypeOf((*ast.Object)(nil)) + positionType = reflect.TypeOf(token.NoPos) + scopePtrType = reflect.TypeOf((*ast.Scope)(nil)) ) @@ -115,6 +126,12 @@ func apply(f func(reflect.Value) reflect.Value, val reflect.Value) reflect.Value return objectPtrNil } + // similarly for scopes: they are likely incorrect after a rewrite; + // replace them with nil + if val.Type() == scopePtrType { + return scopePtrNil + } + switch v := reflect.Indirect(val); v.Kind() { case reflect.Slice: for i := 0; i < v.Len(); i++ { @@ -259,21 +276,21 @@ func subst(m map[string]reflect.Value, pattern reflect.Value, pos reflect.Value) return v case reflect.Struct: - v := reflect.Zero(p.Type()) + v := reflect.New(p.Type()).Elem() for i := 0; i < p.NumField(); i++ { v.Field(i).Set(subst(m, p.Field(i), pos)) } return v case reflect.Ptr: - v := reflect.Zero(p.Type()) + v := reflect.New(p.Type()).Elem() if elem := p.Elem(); elem.IsValid() { v.Set(subst(m, elem, pos).Addr()) } return v case reflect.Interface: - v := reflect.Zero(p.Type()) + v := reflect.New(p.Type()).Elem() if elem := p.Elem(); elem.IsValid() { v.Set(subst(m, elem, pos)) } diff --git a/src/cmd/gofmt/simplify.go b/src/cmd/gofmt/simplify.go index bcc67c4a6..40a9f8f17 100644 --- a/src/cmd/gofmt/simplify.go +++ b/src/cmd/gofmt/simplify.go @@ -26,7 +26,7 @@ func (s *simplifier) Visit(node ast.Node) ast.Visitor { } if eltType != nil { - typ := reflect.NewValue(eltType) + typ := reflect.ValueOf(eltType) for _, x := range outer.Elts { // look at value of indexed/named elements if t, ok := x.(*ast.KeyValueExpr); ok { @@ -37,7 +37,7 @@ func (s *simplifier) Visit(node ast.Node) ast.Visitor { // matches the outer literal's element type exactly, the inner // literal type may be omitted if inner, ok := x.(*ast.CompositeLit); ok { - if match(nil, typ, reflect.NewValue(inner.Type)) { + if match(nil, typ, reflect.ValueOf(inner.Type)) { inner.Type = nil } } diff --git a/src/cmd/gofmt/test.sh b/src/cmd/gofmt/test.sh index 3340c48f0..99ec76932 100755 --- a/src/cmd/gofmt/test.sh +++ b/src/cmd/gofmt/test.sh @@ -36,7 +36,7 @@ apply1() { # the following files are skipped because they are test cases # for syntax errors and thus won't parse in the first place: case `basename "$F"` in - func3.go | const2.go | char_lit1.go | \ + func3.go | const2.go | char_lit1.go | blank1.go | \ bug014.go | bug050.go | bug068.go | bug083.go | bug088.go | \ bug106.go | bug121.go | bug125.go | bug133.go | bug160.go | \ bug163.go | bug166.go | bug169.go | bug217.go | bug222.go | \ diff --git a/src/cmd/gofmt/testdata/rewrite1.golden b/src/cmd/gofmt/testdata/rewrite1.golden index 3f909ff4a..d9beb3705 100644 --- a/src/cmd/gofmt/testdata/rewrite1.golden +++ b/src/cmd/gofmt/testdata/rewrite1.golden @@ -1,3 +1,7 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + package main type Bar int diff --git a/src/cmd/gofmt/testdata/rewrite1.input b/src/cmd/gofmt/testdata/rewrite1.input index 1f10e3601..bdb894320 100644 --- a/src/cmd/gofmt/testdata/rewrite1.input +++ b/src/cmd/gofmt/testdata/rewrite1.input @@ -1,3 +1,7 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + package main type Foo int diff --git a/src/cmd/gofmt/testdata/rewrite2.golden b/src/cmd/gofmt/testdata/rewrite2.golden new file mode 100644 index 000000000..64c67ffa6 --- /dev/null +++ b/src/cmd/gofmt/testdata/rewrite2.golden @@ -0,0 +1,10 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package p + +// Slices have nil Len values in the corresponding ast.ArrayType +// node and reflect.NewValue(slice.Len) is an invalid reflect.Value. +// The rewriter must not crash in that case. Was issue 1696. +func f() []bool {} diff --git a/src/cmd/gofmt/testdata/rewrite2.input b/src/cmd/gofmt/testdata/rewrite2.input new file mode 100644 index 000000000..21171447a --- /dev/null +++ b/src/cmd/gofmt/testdata/rewrite2.input @@ -0,0 +1,10 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package p + +// Slices have nil Len values in the corresponding ast.ArrayType +// node and reflect.NewValue(slice.Len) is an invalid reflect.Value. +// The rewriter must not crash in that case. Was issue 1696. +func f() []int {} diff --git a/src/cmd/goinstall/Makefile b/src/cmd/goinstall/Makefile index aaf202ee7..202797cd5 100644 --- a/src/cmd/goinstall/Makefile +++ b/src/cmd/goinstall/Makefile @@ -10,6 +10,7 @@ GOFILES=\ main.go\ make.go\ parse.go\ + path.go\ syslist.go\ CLEANFILES+=syslist.go diff --git a/src/cmd/goinstall/download.go b/src/cmd/goinstall/download.go index 88befc0dc..7dad596ab 100644 --- a/src/cmd/goinstall/download.go +++ b/src/cmd/goinstall/download.go @@ -37,15 +37,15 @@ var bitbucket = regexp.MustCompile(`^(bitbucket\.org/[a-z0-9A-Z_.\-]+/[a-z0-9A-Z var launchpad = regexp.MustCompile(`^(launchpad\.net/([a-z0-9A-Z_.\-]+(/[a-z0-9A-Z_.\-]+)?|~[a-z0-9A-Z_.\-]+/(\+junk|[a-z0-9A-Z_.\-]+)/[a-z0-9A-Z_.\-]+))(/[a-z0-9A-Z_.\-/]+)?$`) // download checks out or updates pkg from the remote server. -func download(pkg string) (string, os.Error) { +func download(pkg, srcDir string) os.Error { if strings.Contains(pkg, "..") { - return "", os.ErrorString("invalid path (contains ..)") + return os.ErrorString("invalid path (contains ..)") } if m := bitbucket.FindStringSubmatch(pkg); m != nil { - if err := vcsCheckout(&hg, m[1], "http://"+m[1], m[1]); err != nil { - return "", err + if err := vcsCheckout(&hg, srcDir, m[1], "http://"+m[1], m[1]); err != nil { + return err } - return root + pkg, nil + return nil } if m := googlecode.FindStringSubmatch(pkg); m != nil { var v *vcs @@ -58,29 +58,29 @@ func download(pkg string) (string, os.Error) { // regexp only allows hg, svn to get through panic("missing case in download: " + pkg) } - if err := vcsCheckout(v, m[1], "https://"+m[1], m[1]); err != nil { - return "", err + if err := vcsCheckout(v, srcDir, m[1], "https://"+m[1], m[1]); err != nil { + return err } - return root + pkg, nil + return nil } if m := github.FindStringSubmatch(pkg); m != nil { if strings.HasSuffix(m[1], ".git") { - return "", os.ErrorString("repository " + pkg + " should not have .git suffix") + return os.ErrorString("repository " + pkg + " should not have .git suffix") } - if err := vcsCheckout(&git, m[1], "http://"+m[1]+".git", m[1]); err != nil { - return "", err + if err := vcsCheckout(&git, srcDir, m[1], "http://"+m[1]+".git", m[1]); err != nil { + return err } - return root + pkg, nil + return nil } if m := launchpad.FindStringSubmatch(pkg); m != nil { // Either lp.net/<project>[/<series>[/<path>]] // or lp.net/~<user or team>/<project>/<branch>[/<path>] - if err := vcsCheckout(&bzr, m[1], "https://"+m[1], m[1]); err != nil { - return "", err + if err := vcsCheckout(&bzr, srcDir, m[1], "https://"+m[1], m[1]); err != nil { + return err } - return root + pkg, nil + return nil } - return "", os.ErrorString("unknown repository: " + pkg) + return os.ErrorString("unknown repository: " + pkg) } // a vcs represents a version control system @@ -172,8 +172,8 @@ func (v *vcs) updateRepo(dst string) os.Error { // exists and -u was specified on the command line) // the repository at tag/branch "release". If there is no // such tag or branch, it falls back to the repository tip. -func vcsCheckout(vcs *vcs, pkgprefix, repo, dashpath string) os.Error { - dst := filepath.Join(root, filepath.FromSlash(pkgprefix)) +func vcsCheckout(vcs *vcs, srcDir, pkgprefix, repo, dashpath string) os.Error { + dst := filepath.Join(srcDir, filepath.FromSlash(pkgprefix)) dir, err := os.Stat(filepath.Join(dst, vcs.metadir)) if err == nil && !dir.IsDirectory() { return os.ErrorString("not a directory: " + dst) diff --git a/src/cmd/goinstall/main.go b/src/cmd/goinstall/main.go index 8fec8e312..6cd92907a 100644 --- a/src/cmd/goinstall/main.go +++ b/src/cmd/goinstall/main.go @@ -150,6 +150,7 @@ func install(pkg, parent string) { // Check whether package is local or remote. // If remote, download or update it. var dir string + proot := gopath[0] // default to GOROOT local := false if strings.HasPrefix(pkg, "http://") { fmt.Fprintf(os.Stderr, "%s: %s: 'http://' used in remote path, try '%s'\n", argv0, pkg, pkg[7:]) @@ -163,8 +164,9 @@ func install(pkg, parent string) { dir = filepath.Join(root, filepath.FromSlash(pkg)) local = true } else { - var err os.Error - dir, err = download(pkg) + proot = findPkgroot(pkg) + err := download(pkg, proot.srcDir()) + dir = filepath.Join(proot.srcDir(), pkg) if err != nil { fmt.Fprintf(os.Stderr, "%s: %s: %s\n", argv0, pkg, err) errors = true @@ -192,18 +194,11 @@ func install(pkg, parent string) { install(p, pkg) } } - if dirInfo.pkgName == "main" { - if !errors { - fmt.Fprintf(os.Stderr, "%s: %s's dependencies are installed.\n", argv0, pkg) - } - errors = true - visit[pkg] = done - return - } // Install this package. if !errors { - if err := domake(dir, pkg, local); err != nil { + isCmd := dirInfo.pkgName == "main" + if err := domake(dir, pkg, proot, local, isCmd); err != nil { fmt.Fprintf(os.Stderr, "%s: installing %s: %s\n", argv0, pkg, err) errors = true } else if !local && *logPkgs { diff --git a/src/cmd/goinstall/make.go b/src/cmd/goinstall/make.go index ceb119e5a..b2ca82b46 100644 --- a/src/cmd/goinstall/make.go +++ b/src/cmd/goinstall/make.go @@ -9,6 +9,7 @@ package main import ( "bytes" "os" + "path/filepath" "template" ) @@ -17,7 +18,7 @@ import ( // For non-local packages or packages without Makefiles, // domake generates a standard Makefile and passes it // to make on standard input. -func domake(dir, pkg string, local bool) (err os.Error) { +func domake(dir, pkg string, root *pkgroot, local, isCmd bool) (err os.Error) { needMakefile := true if local { _, err := os.Stat(dir + "/Makefile") @@ -28,7 +29,7 @@ func domake(dir, pkg string, local bool) (err os.Error) { cmd := []string{"gomake"} var makefile []byte if needMakefile { - if makefile, err = makeMakefile(dir, pkg); err != nil { + if makefile, err = makeMakefile(dir, pkg, root, isCmd); err != nil { return err } cmd = append(cmd, "-f-") @@ -43,11 +44,26 @@ func domake(dir, pkg string, local bool) (err os.Error) { // makeMakefile computes the standard Makefile for the directory dir // installing as package pkg. It includes all *.go files in the directory // except those in package main and those ending in _test.go. -func makeMakefile(dir, pkg string) ([]byte, os.Error) { +func makeMakefile(dir, pkg string, root *pkgroot, isCmd bool) ([]byte, os.Error) { if !safeName(pkg) { return nil, os.ErrorString("unsafe name: " + pkg) } - dirInfo, err := scanDir(dir, false) + targ := pkg + targDir := root.pkgDir() + if isCmd { + // use the last part of the package name only + _, targ = filepath.Split(pkg) + // if building the working dir use the directory name + if targ == "." { + d, err := filepath.Abs(dir) + if err != nil { + return nil, os.NewError("finding path: " + err.String()) + } + _, targ = filepath.Split(d) + } + targDir = root.binDir() + } + dirInfo, err := scanDir(dir, isCmd) if err != nil { return nil, err } @@ -94,7 +110,10 @@ func makeMakefile(dir, pkg string) ([]byte, os.Error) { } var buf bytes.Buffer - md := makedata{pkg, goFiles, oFiles, cgoFiles, cgoOFiles} + md := makedata{targ, targDir, "pkg", goFiles, oFiles, cgoFiles, cgoOFiles, imports} + if isCmd { + md.Type = "cmd" + } if err := makefileTemplate.Execute(&buf, &md); err != nil { return nil, err } @@ -104,6 +123,9 @@ func makeMakefile(dir, pkg string) ([]byte, os.Error) { var safeBytes = []byte("+-./0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ_abcdefghijklmnopqrstuvwxyz") func safeName(s string) bool { + if s == "" { + return false + } for i := 0; i < len(s); i++ { if c := s[i]; c < 0x80 && bytes.IndexByte(safeBytes, c) < 0 { return false @@ -114,17 +136,21 @@ func safeName(s string) bool { // makedata is the data type for the makefileTemplate. type makedata struct { - Pkg string // package import path + Targ string // build target + TargDir string // build target directory + Type string // build type: "pkg" or "cmd" GoFiles []string // list of non-cgo .go files OFiles []string // list of .$O files CgoFiles []string // list of cgo .go files CgoOFiles []string // list of cgo .o files, without extension + Imports []string // gc/ld import paths } var makefileTemplate = template.MustParse(` include $(GOROOT)/src/Make.inc -TARG={Pkg} +TARG={Targ} +TARGDIR={TargDir} {.section GoFiles} GOFILES=\ @@ -154,6 +180,9 @@ CGO_OFILES=\ {.end} {.end} -include $(GOROOT)/src/Make.pkg +GCIMPORTS={.repeated section Imports}-I "{@}" {.end} +LDIMPORTS={.repeated section Imports}-L "{@}" {.end} + +include $(GOROOT)/src/Make.{Type} `, nil) diff --git a/src/cmd/goinstall/parse.go b/src/cmd/goinstall/parse.go index 0e617903c..a4bb761f2 100644 --- a/src/cmd/goinstall/parse.go +++ b/src/cmd/goinstall/parse.go @@ -88,6 +88,9 @@ func scanDir(dir string, allowMain bool) (info *dirInfo, err os.Error) { if s == "main" && !allowMain { continue } + if s == "documentation" { + continue + } if pkgName == "" { pkgName = s } else if pkgName != s { diff --git a/src/cmd/goinstall/path.go b/src/cmd/goinstall/path.go new file mode 100644 index 000000000..1153e0471 --- /dev/null +++ b/src/cmd/goinstall/path.go @@ -0,0 +1,117 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "log" + "os" + "path/filepath" + "runtime" +) + +var ( + gopath []*pkgroot + imports []string + defaultRoot *pkgroot // default root for remote packages +) + +// set up gopath: parse and validate GOROOT and GOPATH variables +func init() { + p, err := newPkgroot(root) + if err != nil { + log.Fatalf("Invalid GOROOT %q: %v", root, err) + } + p.goroot = true + gopath = []*pkgroot{p} + + for _, p := range filepath.SplitList(os.Getenv("GOPATH")) { + if p == "" { + continue + } + r, err := newPkgroot(p) + if err != nil { + log.Printf("Invalid GOPATH %q: %v", p, err) + continue + } + gopath = append(gopath, r) + imports = append(imports, r.pkgDir()) + + // select first GOPATH entry as default + if defaultRoot == nil { + defaultRoot = r + } + } + + // use GOROOT if no valid GOPATH specified + if defaultRoot == nil { + defaultRoot = gopath[0] + } +} + +type pkgroot struct { + path string + goroot bool // TODO(adg): remove this once Go tree re-organized +} + +func newPkgroot(p string) (*pkgroot, os.Error) { + if !filepath.IsAbs(p) { + return nil, os.NewError("must be absolute") + } + ep, err := filepath.EvalSymlinks(p) + if err != nil { + return nil, err + } + return &pkgroot{path: ep}, nil +} + +func (r *pkgroot) srcDir() string { + if r.goroot { + return filepath.Join(r.path, "src", "pkg") + } + return filepath.Join(r.path, "src") +} + +func (r *pkgroot) pkgDir() string { + goos, goarch := runtime.GOOS, runtime.GOARCH + if e := os.Getenv("GOOS"); e != "" { + goos = e + } + if e := os.Getenv("GOARCH"); e != "" { + goarch = e + } + return filepath.Join(r.path, "pkg", goos+"_"+goarch) +} + +func (r *pkgroot) binDir() string { + return filepath.Join(r.path, "bin") +} + +func (r *pkgroot) hasSrcDir(name string) bool { + fi, err := os.Stat(filepath.Join(r.srcDir(), name)) + if err != nil { + return false + } + return fi.IsDirectory() +} + +func (r *pkgroot) hasPkg(name string) bool { + fi, err := os.Stat(filepath.Join(r.pkgDir(), name+".a")) + if err != nil { + return false + } + return fi.IsRegular() + // TODO(adg): check object version is consistent +} + +// findPkgroot searches each of the gopath roots +// for the source code for the given import path. +func findPkgroot(importPath string) *pkgroot { + for _, r := range gopath { + if r.hasSrcDir(importPath) { + return r + } + } + return defaultRoot +} diff --git a/src/cmd/gopack/ar.c b/src/cmd/gopack/ar.c index dc3899f37..017978ced 100644 --- a/src/cmd/gopack/ar.c +++ b/src/cmd/gopack/ar.c @@ -144,6 +144,7 @@ char *file; /* current file or member being worked on */ Biobuf bout; Biobuf bar; char *prefix; +int pkgdefsafe; /* was __.PKGDEF marked safe? */ void arcopy(Biobuf*, Arfile*, Armember*); int arcreate(char*); @@ -177,6 +178,7 @@ void scanpkg(Biobuf*, long); void select(int*, long); void setcom(void(*)(char*, int, char**)); void skip(Biobuf*, vlong); +void checksafe(Biobuf*, vlong); int symcomp(void*, void*); void trim(char*, char*, int); void usage(void); @@ -322,9 +324,9 @@ rcmd(char *arname, int count, char **files) skip(&bar, bp->size); continue; } - /* pitch pkgdef file */ + /* pitch pkgdef file but remember whether it was marked safe */ if (gflag && strcmp(file, pkgdef) == 0) { - skip(&bar, bp->size); + checksafe(&bar, bp->size); continue; } /* @@ -773,7 +775,8 @@ scanpkg(Biobuf *b, long size) goto foundstart; } // fprint(2, "gopack: warning: no package import section in %s\n", file); - safe = 0; // non-Go file (C or assembly) + if(b != &bar || !pkgdefsafe) + safe = 0; // non-Go file (C or assembly) return; foundstart: @@ -807,7 +810,7 @@ foundstart: pkgname = armalloc(pkg - data + 1); memmove(pkgname, data, pkg - data); pkgname[pkg-data] = '\0'; - if(strcmp(pkg, " safe\n") != 0) + if(strcmp(pkg, " safe\n") != 0 && (b != &bar || !pkgdefsafe)) safe = 0; start = Boffset(b); // after package statement first = 0; @@ -1094,6 +1097,36 @@ skip(Biobuf *bp, vlong len) Bseek(bp, len, 1); } +void +checksafe(Biobuf *bp, vlong len) +{ + char *p; + vlong end; + + if (len & 01) + len++; + end = Boffset(bp) + len; + + p = Brdline(bp, '\n'); + if(p == nil || strncmp(p, "go object ", 10) != 0) + goto done; + for(;;) { + p = Brdline(bp, '\n'); + if(p == nil || Boffset(bp) >= end) + goto done; + if(strncmp(p, "$$\n", 3) == 0) + break; + } + p = Brdline(bp, '\n'); + if(p == nil || Boffset(bp) > end) + goto done; + if(Blinelen(bp) > 8+6 && strncmp(p, "package ", 8) == 0 && strncmp(p+Blinelen(bp)-6, " safe\n", 6) == 0) + pkgdefsafe = 1; + +done: + Bseek(bp, end, 0); +} + /* * Stream the three temp files to an archive */ @@ -1676,6 +1709,10 @@ arread_cutprefix(Biobuf *b, Armember *bp) offset = o; } } + } else { + // didn't find the whole prefix. + // give up and let it emit the entire name. + inprefix = nil; } // copy instructions diff --git a/src/cmd/gotest/gotest.go b/src/cmd/gotest/gotest.go index 138216e68..a7ba8dd11 100644 --- a/src/cmd/gotest/gotest.go +++ b/src/cmd/gotest/gotest.go @@ -16,6 +16,7 @@ import ( "path/filepath" "runtime" "strings" + "time" "unicode" "utf8" ) @@ -51,6 +52,13 @@ var ( xFlag bool ) +// elapsed returns time elapsed since gotest started. +func elapsed() float64 { + return float64(time.Nanoseconds()-start) / 1e9 +} + +var start = time.Nanoseconds() + // File represents a file that contains tests. type File struct { name string @@ -80,6 +88,9 @@ func main() { if !cFlag { runTestWithArgs("./" + O + ".out") } + if xFlag { + fmt.Printf("gotest %.2fs: done\n", elapsed()) + } } // needMakefile tests that we have a Makefile in this directory. @@ -119,7 +130,10 @@ func setEnvironment() { // Basic environment. GOROOT = runtime.GOROOT() addEnv("GOROOT", GOROOT) - GOARCH = runtime.GOARCH + GOARCH = os.Getenv("GOARCH") + if GOARCH == "" { + GOARCH = runtime.GOARCH + } addEnv("GOARCH", GOARCH) O = theChar[GOARCH] if O == "" { @@ -254,7 +268,12 @@ func runTestWithArgs(binary string) { // retrieve standard output. func doRun(argv []string, returnStdout bool) string { if xFlag { - fmt.Printf("gotest: %s\n", strings.Join(argv, " ")) + fmt.Printf("gotest %.2fs: %s\n", elapsed(), strings.Join(argv, " ")) + t := -time.Nanoseconds() + defer func() { + t += time.Nanoseconds() + fmt.Printf(" [+%.2fs]\n", float64(t)/1e9) + }() } command := argv[0] if runtime.GOOS == "windows" && command == "gomake" { @@ -266,7 +285,7 @@ func doRun(argv []string, returnStdout bool) string { } cmd += `"` + v + `"` } - argv = []string{"cmd", "/c", "sh", "-c", cmd} + argv = []string{"sh", "-c", cmd} } var err os.Error argv[0], err = exec.LookPath(argv[0]) @@ -359,7 +378,7 @@ func writeTestmainGo() { fmt.Fprintf(b, "import %q\n", "./_xtest_") } fmt.Fprintf(b, "import %q\n", "testing") - fmt.Fprintf(b, "import __os__ %q\n", "os") // rename in case tested package is called os + fmt.Fprintf(b, "import __os__ %q\n", "os") // rename in case tested package is called os fmt.Fprintf(b, "import __regexp__ %q\n", "regexp") // rename in case tested package is called regexp fmt.Fprintln(b) // for gofmt @@ -374,7 +393,7 @@ func writeTestmainGo() { fmt.Fprintln(b) // Benchmarks. - fmt.Fprintln(b, "var benchmarks = []testing.InternalBenchmark{") + fmt.Fprintf(b, "var benchmarks = []testing.InternalBenchmark{") for _, f := range files { for _, bm := range f.benchmarks { fmt.Fprintf(b, "\t{\"%s.%s\", %s.%s},\n", f.pkg, bm, notMain(f.pkg), bm) diff --git a/src/cmd/ld/data.c b/src/cmd/ld/data.c index d27416dac..0cb2b2138 100644 --- a/src/cmd/ld/data.c +++ b/src/cmd/ld/data.c @@ -312,7 +312,7 @@ symgrow(Sym *s, int32 siz) } void -savedata(Sym *s, Prog *p) +savedata(Sym *s, Prog *p, char *pn) { int32 off, siz, i, fl; uchar *cast; @@ -321,8 +321,10 @@ savedata(Sym *s, Prog *p) off = p->from.offset; siz = p->datasize; + if(off < 0 || siz < 0 || off >= 1<<30 || siz >= 100) + mangle(pn); symgrow(s, off+siz); - + switch(p->to.type) { default: diag("bad data: %P", p); @@ -876,7 +878,7 @@ textaddress(void) void address(void) { - Section *s, *text, *data, *rodata, *bss; + Section *s, *text, *data, *rodata; Sym *sym, *sub; uvlong va; @@ -911,7 +913,6 @@ address(void) text = segtext.sect; rodata = segtext.sect->next; data = segdata.sect; - bss = segdata.sect->next; for(sym = datap; sym != nil; sym = sym->next) { cursym = sym; diff --git a/src/cmd/ld/dwarf.c b/src/cmd/ld/dwarf.c index fa55fcbb4..98b068008 100644 --- a/src/cmd/ld/dwarf.c +++ b/src/cmd/ld/dwarf.c @@ -1376,20 +1376,18 @@ synthesizemaptypes(DWDie *die) static void synthesizechantypes(DWDie *die) { - DWDie *sudog, *waitq, *link, *hchan, + DWDie *sudog, *waitq, *hchan, *dws, *dww, *dwh, *elemtype; DWAttr *a; - int elemsize, linksize, sudogsize; + int elemsize, sudogsize; sudog = defgotype(lookup_or_diag("type.runtime.sudog")); waitq = defgotype(lookup_or_diag("type.runtime.waitq")); - link = defgotype(lookup_or_diag("type.runtime.link")); hchan = defgotype(lookup_or_diag("type.runtime.hchan")); - if (sudog == nil || waitq == nil || link == nil || hchan == nil) + if (sudog == nil || waitq == nil || hchan == nil) return; sudogsize = getattr(sudog, DW_AT_byte_size)->value; - linksize = getattr(link, DW_AT_byte_size)->value; for (; die != nil; die = die->link) { if (die->abbrev != DW_ABRV_CHANTYPE) @@ -1422,7 +1420,7 @@ synthesizechantypes(DWDie *die) copychildren(dwh, hchan); substitutetype(dwh, "recvq", dww); substitutetype(dwh, "sendq", dww); - substitutetype(dwh, "free", dws); + substitutetype(dwh, "free", defptrto(dws)); newattr(dwh, DW_AT_byte_size, DW_CLS_CONSTANT, getattr(hchan, DW_AT_byte_size)->value, NULL); @@ -2569,12 +2567,8 @@ dwarfaddpeheaders(void) newPEDWARFSection(".debug_line", linesize); newPEDWARFSection(".debug_frame", framesize); newPEDWARFSection(".debug_info", infosize); - if (pubnamessize > 0) - newPEDWARFSection(".debug_pubnames", pubnamessize); - if (pubtypessize > 0) - newPEDWARFSection(".debug_pubtypes", pubtypessize); - if (arangessize > 0) - newPEDWARFSection(".debug_aranges", arangessize); - if (gdbscriptsize > 0) - newPEDWARFSection(".debug_gdb_scripts", gdbscriptsize); + newPEDWARFSection(".debug_pubnames", pubnamessize); + newPEDWARFSection(".debug_pubtypes", pubtypessize); + newPEDWARFSection(".debug_aranges", arangessize); + newPEDWARFSection(".debug_gdb_scripts", gdbscriptsize); } diff --git a/src/cmd/ld/elf.c b/src/cmd/ld/elf.c index b0cce4985..fc917b203 100644 --- a/src/cmd/ld/elf.c +++ b/src/cmd/ld/elf.c @@ -331,17 +331,62 @@ elfinterp(ElfShdr *sh, uint64 startva, char *p) } extern int nelfsym; +int elfverneed; + +typedef struct Elfaux Elfaux; +typedef struct Elflib Elflib; + +struct Elflib +{ + Elflib *next; + Elfaux *aux; + char *file; +}; + +struct Elfaux +{ + Elfaux *next; + int num; + char *vers; +}; + +Elfaux* +addelflib(Elflib **list, char *file, char *vers) +{ + Elflib *lib; + Elfaux *aux; + + for(lib=*list; lib; lib=lib->next) + if(strcmp(lib->file, file) == 0) + goto havelib; + lib = mal(sizeof *lib); + lib->next = *list; + lib->file = file; + *list = lib; +havelib: + for(aux=lib->aux; aux; aux=aux->next) + if(strcmp(aux->vers, vers) == 0) + goto haveaux; + aux = mal(sizeof *aux); + aux->next = lib->aux; + aux->vers = vers; + lib->aux = aux; +haveaux: + return aux; +} void elfdynhash(void) { - Sym *s, *sy; - int i, nbucket, b; - uchar *pc; - uint32 hc, g; - uint32 *chain, *buckets; + Sym *s, *sy, *dynstr; + int i, j, nbucket, b, nfile; + uint32 hc, *chain, *buckets; int nsym; char *name; + Elfaux **need; + Elflib *needlib; + Elflib *l; + Elfaux *x; if(!iself) return; @@ -358,29 +403,29 @@ elfdynhash(void) i >>= 1; } - chain = malloc(nsym * sizeof(uint32)); - buckets = malloc(nbucket * sizeof(uint32)); - if(chain == nil || buckets == nil) { + needlib = nil; + need = malloc(nsym * sizeof need[0]); + chain = malloc(nsym * sizeof chain[0]); + buckets = malloc(nbucket * sizeof buckets[0]); + if(need == nil || chain == nil || buckets == nil) { cursym = nil; diag("out of memory"); errorexit(); } - memset(chain, 0, nsym * sizeof(uint32)); - memset(buckets, 0, nbucket * sizeof(uint32)); + memset(need, 0, nsym * sizeof need[0]); + memset(chain, 0, nsym * sizeof chain[0]); + memset(buckets, 0, nbucket * sizeof buckets[0]); for(sy=allsym; sy!=S; sy=sy->allsym) { if (sy->dynid <= 0) continue; - hc = 0; + if(sy->dynimpvers) + need[sy->dynid] = addelflib(&needlib, sy->dynimplib, sy->dynimpvers); + name = sy->dynimpname; if(name == nil) name = sy->name; - for(pc = (uchar*)name; *pc; pc++) { - hc = (hc<<4) + *pc; - g = hc & 0xf0000000; - hc ^= g >> 24; - hc &= ~g; - } + hc = elfhash((uchar*)name); b = hc % nbucket; chain[sy->dynid] = buckets[b]; @@ -396,8 +441,62 @@ elfdynhash(void) free(chain); free(buckets); + + // version symbols + dynstr = lookup(".dynstr", 0); + s = lookup(".gnu.version_r", 0); + i = 2; + nfile = 0; + for(l=needlib; l; l=l->next) { + nfile++; + // header + adduint16(s, 1); // table version + j = 0; + for(x=l->aux; x; x=x->next) + j++; + adduint16(s, j); // aux count + adduint32(s, addstring(dynstr, l->file)); // file string offset + adduint32(s, 16); // offset from header to first aux + if(l->next) + adduint32(s, 16+j*16); // offset from this header to next + else + adduint32(s, 0); + + for(x=l->aux; x; x=x->next) { + x->num = i++; + // aux struct + adduint32(s, elfhash((uchar*)x->vers)); // hash + adduint16(s, 0); // flags + adduint16(s, x->num); // other - index we refer to this by + adduint32(s, addstring(dynstr, x->vers)); // version string offset + if(x->next) + adduint32(s, 16); // offset from this aux to next + else + adduint32(s, 0); + } + } + + // version references + s = lookup(".gnu.version", 0); + for(i=0; i<nsym; i++) { + if(i == 0) + adduint16(s, 0); // first entry - no symbol + else if(need[i] == nil) + adduint16(s, 1); // global + else + adduint16(s, need[i]->num); + } - elfwritedynent(lookup(".dynamic", 0), DT_NULL, 0); + free(need); + + s = lookup(".dynamic", 0); + elfverneed = nfile; + if(elfverneed) { + elfwritedynentsym(s, DT_VERNEED, lookup(".gnu.version_r", 0)); + elfwritedynent(s, DT_VERNEEDNUM, nfile); + elfwritedynentsym(s, DT_VERSYM, lookup(".gnu.version", 0)); + } + elfwritedynent(s, DT_NULL, 0); } ElfPhdr* diff --git a/src/cmd/ld/elf.h b/src/cmd/ld/elf.h index b27ae679b..08583cc8f 100644 --- a/src/cmd/ld/elf.h +++ b/src/cmd/ld/elf.h @@ -216,6 +216,9 @@ typedef struct { #define SHT_SYMTAB_SHNDX 18 /* Section indexes (see SHN_XINDEX). */ #define SHT_LOOS 0x60000000 /* First of OS specific semantics */ #define SHT_HIOS 0x6fffffff /* Last of OS specific semantics */ +#define SHT_GNU_VERDEF 0x6ffffffd +#define SHT_GNU_VERNEED 0x6ffffffe +#define SHT_GNU_VERSYM 0x6fffffff #define SHT_LOPROC 0x70000000 /* reserved range for processor */ #define SHT_HIPROC 0x7fffffff /* specific section header types */ #define SHT_LOUSER 0x80000000 /* reserved range for application */ @@ -311,6 +314,10 @@ typedef struct { #define DT_LOPROC 0x70000000 /* First processor-specific type. */ #define DT_HIPROC 0x7fffffff /* Last processor-specific type. */ +#define DT_VERNEED 0x6ffffffe +#define DT_VERNEEDNUM 0x6fffffff +#define DT_VERSYM 0x6ffffff0 + /* Values for DT_FLAGS */ #define DF_ORIGIN 0x0001 /* Indicates that the object being loaded may make reference to the $ORIGIN substitution @@ -962,12 +969,14 @@ uint64 endelf(void); extern int numelfphdr; extern int numelfshdr; extern int iself; +extern int elfverneed; int elfwriteinterp(void); void elfinterp(ElfShdr*, uint64, char*); void elfdynhash(void); ElfPhdr* elfphload(Segment*); ElfShdr* elfshbits(Section*); void elfsetstring(char*, int); +void elfaddverneed(Sym*); /* * Total amount of space to reserve at the start of the file diff --git a/src/cmd/ld/go.c b/src/cmd/ld/go.c index 055163d08..e52c5cb34 100644 --- a/src/cmd/ld/go.c +++ b/src/cmd/ld/go.c @@ -412,7 +412,7 @@ parsemethod(char **pp, char *ep, char **methp) static void loaddynimport(char *file, char *pkg, char *p, int n) { - char *pend, *next, *name, *def, *p0, *lib; + char *pend, *next, *name, *def, *p0, *lib, *q; Sym *s; pend = p + n; @@ -445,6 +445,12 @@ loaddynimport(char *file, char *pkg, char *p, int n) *strchr(name, ' ') = 0; *strchr(def, ' ') = 0; + if(debug['d']) { + fprint(2, "%s: %s: cannot use dynamic imports with -d flag\n", argv0, file); + nerrors++; + return; + } + if(strcmp(name, "_") == 0 && strcmp(def, "_") == 0) { // allow #pragma dynimport _ _ "foo.so" // to force a link of foo.so. @@ -453,17 +459,21 @@ loaddynimport(char *file, char *pkg, char *p, int n) } name = expandpkg(name, pkg); + q = strchr(def, '@'); + if(q) + *q++ = '\0'; s = lookup(name, 0); if(s->type == 0 || s->type == SXREF) { s->dynimplib = lib; s->dynimpname = def; + s->dynimpvers = q; s->type = SDYNIMPORT; } } return; err: - fprint(2, "%s: invalid dynimport line: %s\n", argv0, p0); + fprint(2, "%s: %s: invalid dynimport line: %s\n", argv0, file, p0); nerrors++; } diff --git a/src/cmd/ld/ldelf.c b/src/cmd/ld/ldelf.c index 44bbe68ee..d61020e49 100644 --- a/src/cmd/ld/ldelf.c +++ b/src/cmd/ld/ldelf.c @@ -319,7 +319,7 @@ ldelf(Biobuf *f, char *pkg, int64 len, char *pn) char *name; int i, j, rela, is64, n; uchar hdrbuf[64]; - uchar *p, *dp; + uchar *p; ElfHdrBytes *hdr; ElfObj *obj; ElfSect *sect, *rsect; @@ -561,7 +561,6 @@ ldelf(Biobuf *f, char *pkg, int64 len, char *pn) n = rsect->size/(4+4*is64)/(2+rela); r = mal(n*sizeof r[0]); p = rsect->base; - dp = sect->base; for(j=0; j<n; j++) { add = 0; rp = &r[j]; diff --git a/src/cmd/ld/lib.c b/src/cmd/ld/lib.c index 8cd570463..15219ba11 100644 --- a/src/cmd/ld/lib.c +++ b/src/cmd/ld/lib.c @@ -438,7 +438,7 @@ ldobj(Biobuf *f, char *pkg, int64 len, char *pn, int whence) return; } t = smprint("%s %s %s", getgoos(), thestring, getgoversion()); - if(strcmp(line+10, t) != 0) { + if(strcmp(line+10, t) != 0 && !debug['f']) { diag("%s: object is [%s] expected [%s]", pn, line+10, t); free(t); return; @@ -1033,7 +1033,7 @@ mkfwd(void) Prog *p; int i; int32 dwn[LOG], cnt[LOG]; - Prog *lst[LOG], *last; + Prog *lst[LOG]; for(i=0; i<LOG; i++) { if(i == 0) @@ -1044,7 +1044,6 @@ mkfwd(void) lst[i] = P; } i = 0; - last = nil; for(cursym = textp; cursym != nil; cursym = cursym->next) { for(p = cursym->text; p != P; p = p->link) { if(p->link == P) { diff --git a/src/cmd/ld/lib.h b/src/cmd/ld/lib.h index 646aeb535..8b603a04a 100644 --- a/src/cmd/ld/lib.h +++ b/src/cmd/ld/lib.h @@ -173,7 +173,7 @@ void datblk(int32, int32); Sym* datsort(Sym*); void reloc(void); void relocsym(Sym*); -void savedata(Sym*, Prog*); +void savedata(Sym*, Prog*, char*); void symgrow(Sym*, int32); vlong addstring(Sym*, char*); vlong adduint32(Sym*, uint32); diff --git a/src/cmd/ld/macho.c b/src/cmd/ld/macho.c index c8d7c4a6d..01349bb10 100644 --- a/src/cmd/ld/macho.c +++ b/src/cmd/ld/macho.c @@ -12,10 +12,10 @@ static int macho64; static MachoHdr hdr; -static MachoLoad load[16]; +static MachoLoad *load; static MachoSeg seg[16]; static MachoDebug xdebug[16]; -static int nload, nseg, ndebug, nsect; +static int nload, mload, nseg, ndebug, nsect; void machoinit(void) @@ -43,11 +43,18 @@ newMachoLoad(uint32 type, uint32 ndata) { MachoLoad *l; - if(nload >= nelem(load)) { - diag("too many loads"); - errorexit(); + if(nload >= mload) { + if(mload == 0) + mload = 1; + else + mload *= 2; + load = realloc(load, mload*sizeof load[0]); + if(load == nil) { + diag("out of memory"); + errorexit(); + } } - + if(macho64 && (ndata & 1)) ndata++; @@ -342,11 +349,13 @@ asmbmacho(void) msect = newMachoSect(ms, "__data"); msect->addr = va+v; - msect->size = symaddr(lookup(".got", 0)) - msect->addr; msect->off = v; + msect->size = segdata.filelen; s = lookup(".got", 0); if(s->size > 0) { + msect->size = symaddr(s) - msect->addr; + msect = newMachoSect(ms, "__nl_symbol_ptr"); msect->addr = symaddr(s); msect->size = s->size; diff --git a/src/cmd/ld/pe.c b/src/cmd/ld/pe.c index 0d4240e36..d523ca9c5 100644 --- a/src/cmd/ld/pe.c +++ b/src/cmd/ld/pe.c @@ -415,6 +415,9 @@ newPEDWARFSection(char *name, vlong size) IMAGE_SECTION_HEADER *h; char s[8]; + if(size == 0) + return nil; + if(nextsymoff+strlen(name)+1 > sizeof(symnames)) { diag("pe string table is full"); errorexit(); diff --git a/src/cmd/ld/symtab.c b/src/cmd/ld/symtab.c index aefe0b1af..da698fcc0 100644 --- a/src/cmd/ld/symtab.c +++ b/src/cmd/ld/symtab.c @@ -140,29 +140,25 @@ void putplan9sym(Sym *x, char *s, int t, vlong addr, vlong size, int ver, Sym *go) { int i; - + switch(t) { case 'T': - case 't': case 'L': - case 'l': case 'D': - case 'd': case 'B': - case 'b': + if(ver) + t += 'a' - 'A'; case 'a': case 'p': - case 'f': case 'z': case 'Z': - case 'm': lputb(addr); cput(t+0x80); /* 0x80 is variable length */ - + if(t == 'z' || t == 'Z') { - cput(0); + cput(s[0]); for(i=1; s[i] != 0 || s[i+1] != 0; i += 2) { cput(s[i]); cput(s[i+1]); @@ -172,19 +168,17 @@ putplan9sym(Sym *x, char *s, int t, vlong addr, vlong size, int ver, Sym *go) i++; } else { /* skip the '<' in filenames */ - if(t=='f') + if(t == 'f') s++; - for(i=0; s[i]; i++) cput(s[i]); cput(0); } - symsize += 4 + 1 + i + 1; break; default: return; - }; + }; } void diff --git a/src/env.bash b/src/env.bash index c1055d561..ca3ecebe8 100644 --- a/src/env.bash +++ b/src/env.bash @@ -3,6 +3,16 @@ # Use of this source code is governed by a BSD-style # license that can be found in the LICENSE file. +# If set to a Windows-style path convert to an MSYS-Unix +# one using the built-in shell commands. +if [[ "$GOROOT" == *:* ]]; then + GOROOT=$(cd "$GOROOT"; pwd) +fi + +if [[ "$GOBIN" == *:* ]]; then + GOBIN=$(cd "$GOBIN"; pwd) +fi + export GOROOT=${GOROOT:-$(cd ..; pwd)} if ! test -f "$GOROOT"/include/u.h diff --git a/src/lib9/create.c b/src/lib9/create.c index 59845ba91..d7023aea0 100644 --- a/src/lib9/create.c +++ b/src/lib9/create.c @@ -37,9 +37,8 @@ THE SOFTWARE. int p9create(char *path, int mode, ulong perm) { - int fd, umode, rclose, rdwr; + int fd, umode, rclose; - rdwr = mode&3; rclose = mode&ORCLOSE; mode &= ~ORCLOSE; diff --git a/src/libmach/executable.c b/src/libmach/executable.c index 33000ed07..e90334438 100644 --- a/src/libmach/executable.c +++ b/src/libmach/executable.c @@ -991,7 +991,6 @@ machdotout(int fd, Fhdr *fp, ExecHdr *hp) { uvlong (*swav)(uvlong); uint32 (*swal)(uint32); - ushort (*swab)(ushort); Machhdr *mp; MachCmd **cmd; MachSymSeg *symtab; @@ -1012,7 +1011,6 @@ machdotout(int fd, Fhdr *fp, ExecHdr *hp) return 0; } - swab = leswab; swal = leswal; swav = leswav; diff --git a/src/libmach/obj.c b/src/libmach/obj.c index 1ffe7a0ee..7d660787b 100644 --- a/src/libmach/obj.c +++ b/src/libmach/obj.c @@ -215,7 +215,7 @@ processprog(Prog *p, int doautos) { if(p->kind == aNone) return 1; - if(p->sym < 0 || p->sym >= NNAMES) + if((schar)p->sym < 0 || p->sym >= NNAMES) return 0; switch(p->kind) { diff --git a/src/pkg/Makefile b/src/pkg/Makefile index e45b39e86..b046064a6 100644 --- a/src/pkg/Makefile +++ b/src/pkg/Makefile @@ -100,6 +100,7 @@ DIRS=\ html\ http\ http/cgi\ + http/fcgi\ http/pprof\ http/httptest\ image\ @@ -120,6 +121,7 @@ DIRS=\ netchan\ os\ os/signal\ + os/user\ patch\ path\ path/filepath\ @@ -183,7 +185,6 @@ NOTEST+=\ hash\ http/pprof\ http/httptest\ - image/jpeg\ net/dict\ rand\ runtime/cgo\ @@ -202,11 +203,6 @@ NOTEST+=\ NOBENCH+=\ container/vector\ -# Disable tests that depend on an external network. -ifeq ($(DISABLE_NET_TESTS),1) -NOTEST+=net syslog -endif - # Disable tests that windows cannot run yet. ifeq ($(GOOS),windows) NOTEST+=os/signal # no signals diff --git a/src/pkg/archive/tar/common.go b/src/pkg/archive/tar/common.go index 5b781ff3d..528858765 100644 --- a/src/pkg/archive/tar/common.go +++ b/src/pkg/archive/tar/common.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The tar package implements access to tar archives. +// Package tar implements access to tar archives. // It aims to cover most of the variations, including those produced // by GNU and BSD tars. // diff --git a/src/pkg/archive/tar/reader.go b/src/pkg/archive/tar/reader.go index 0cfdf355d..ad06b6dac 100644 --- a/src/pkg/archive/tar/reader.go +++ b/src/pkg/archive/tar/reader.go @@ -10,6 +10,7 @@ package tar import ( "bytes" "io" + "io/ioutil" "os" "strconv" ) @@ -84,12 +85,6 @@ func (tr *Reader) octal(b []byte) int64 { return int64(x) } -type ignoreWriter struct{} - -func (ignoreWriter) Write(b []byte) (n int, err os.Error) { - return len(b), nil -} - // Skip any unread bytes in the existing file entry, as well as any alignment padding. func (tr *Reader) skipUnread() { nr := tr.nb + tr.pad // number of bytes to skip @@ -99,7 +94,7 @@ func (tr *Reader) skipUnread() { return } } - _, tr.err = io.Copyn(ignoreWriter{}, tr.r, nr) + _, tr.err = io.Copyn(ioutil.Discard, tr.r, nr) } func (tr *Reader) verifyChecksum(header []byte) bool { diff --git a/src/pkg/archive/zip/reader.go b/src/pkg/archive/zip/reader.go index 0391d6441..17464c5d8 100644 --- a/src/pkg/archive/zip/reader.go +++ b/src/pkg/archive/zip/reader.go @@ -3,7 +3,7 @@ // license that can be found in the LICENSE file. /* -The zip package provides support for reading ZIP archives. +Package zip provides support for reading ZIP archives. See: http://www.pkware.com/documents/casestudies/APPNOTE.TXT diff --git a/src/pkg/asn1/asn1.go b/src/pkg/asn1/asn1.go index 8c99bd7a0..5f470aed7 100644 --- a/src/pkg/asn1/asn1.go +++ b/src/pkg/asn1/asn1.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The asn1 package implements parsing of DER-encoded ASN.1 data structures, +// Package asn1 implements parsing of DER-encoded ASN.1 data structures, // as defined in ITU-T Rec X.690. // // See also ``A Layman's Guide to a Subset of ASN.1, BER, and DER,'' @@ -418,13 +418,13 @@ func parseSequenceOf(bytes []byte, sliceType reflect.Type, elemType reflect.Type } var ( - bitStringType = reflect.Typeof(BitString{}) - objectIdentifierType = reflect.Typeof(ObjectIdentifier{}) - enumeratedType = reflect.Typeof(Enumerated(0)) - flagType = reflect.Typeof(Flag(false)) - timeType = reflect.Typeof(&time.Time{}) - rawValueType = reflect.Typeof(RawValue{}) - rawContentsType = reflect.Typeof(RawContent(nil)) + bitStringType = reflect.TypeOf(BitString{}) + objectIdentifierType = reflect.TypeOf(ObjectIdentifier{}) + enumeratedType = reflect.TypeOf(Enumerated(0)) + flagType = reflect.TypeOf(Flag(false)) + timeType = reflect.TypeOf(&time.Time{}) + rawValueType = reflect.TypeOf(RawValue{}) + rawContentsType = reflect.TypeOf(RawContent(nil)) ) // invalidLength returns true iff offset + length > sliceLength, or if the @@ -461,7 +461,7 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam } result := RawValue{t.class, t.tag, t.isCompound, bytes[offset : offset+t.length], bytes[initOffset : offset+t.length]} offset += t.length - v.Set(reflect.NewValue(result)) + v.Set(reflect.ValueOf(result)) return } @@ -505,7 +505,7 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam return } if result != nil { - v.Set(reflect.NewValue(result)) + v.Set(reflect.ValueOf(result)) } return } @@ -605,14 +605,14 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam newSlice, err1 := parseObjectIdentifier(innerBytes) v.Set(reflect.MakeSlice(v.Type(), len(newSlice), len(newSlice))) if err1 == nil { - reflect.Copy(v, reflect.NewValue(newSlice)) + reflect.Copy(v, reflect.ValueOf(newSlice)) } err = err1 return case bitStringType: bs, err1 := parseBitString(innerBytes) if err1 == nil { - v.Set(reflect.NewValue(bs)) + v.Set(reflect.ValueOf(bs)) } err = err1 return @@ -625,7 +625,7 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam time, err1 = parseGeneralizedTime(innerBytes) } if err1 == nil { - v.Set(reflect.NewValue(time)) + v.Set(reflect.ValueOf(time)) } err = err1 return @@ -671,7 +671,7 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam if structType.NumField() > 0 && structType.Field(0).Type == rawContentsType { bytes := bytes[initOffset:offset] - val.Field(0).Set(reflect.NewValue(RawContent(bytes))) + val.Field(0).Set(reflect.ValueOf(RawContent(bytes))) } innerOffset := 0 @@ -693,7 +693,7 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam sliceType := fieldType if sliceType.Elem().Kind() == reflect.Uint8 { val.Set(reflect.MakeSlice(sliceType, len(innerBytes), len(innerBytes))) - reflect.Copy(val, reflect.NewValue(innerBytes)) + reflect.Copy(val, reflect.ValueOf(innerBytes)) return } newSlice, err1 := parseSequenceOf(innerBytes, sliceType, sliceType.Elem()) @@ -798,7 +798,7 @@ func Unmarshal(b []byte, val interface{}) (rest []byte, err os.Error) { // UnmarshalWithParams allows field parameters to be specified for the // top-level element. The form of the params is the same as the field tags. func UnmarshalWithParams(b []byte, val interface{}, params string) (rest []byte, err os.Error) { - v := reflect.NewValue(val).Elem() + v := reflect.ValueOf(val).Elem() offset, err := parseField(v, b, 0, parseFieldParameters(params)) if err != nil { return nil, err diff --git a/src/pkg/asn1/asn1_test.go b/src/pkg/asn1/asn1_test.go index 018c534eb..78f562805 100644 --- a/src/pkg/asn1/asn1_test.go +++ b/src/pkg/asn1/asn1_test.go @@ -267,11 +267,6 @@ func TestParseFieldParameters(t *testing.T) { } } -type unmarshalTest struct { - in []byte - out interface{} -} - type TestObjectIdentifierStruct struct { OID ObjectIdentifier } @@ -290,7 +285,10 @@ type TestElementsAfterString struct { A, B int } -var unmarshalTestData []unmarshalTest = []unmarshalTest{ +var unmarshalTestData = []struct { + in []byte + out interface{} +}{ {[]byte{0x02, 0x01, 0x42}, newInt(0x42)}, {[]byte{0x30, 0x08, 0x06, 0x06, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d}, &TestObjectIdentifierStruct{[]int{1, 2, 840, 113549}}}, {[]byte{0x03, 0x04, 0x06, 0x6e, 0x5d, 0xc0}, &BitString{[]byte{110, 93, 192}, 18}}, @@ -309,9 +307,7 @@ var unmarshalTestData []unmarshalTest = []unmarshalTest{ func TestUnmarshal(t *testing.T) { for i, test := range unmarshalTestData { - pv := reflect.Zero(reflect.NewValue(test.out).Type()) - zv := reflect.Zero(pv.Type().Elem()) - pv.Set(zv.Addr()) + pv := reflect.New(reflect.TypeOf(test.out).Elem()) val := pv.Interface() _, err := Unmarshal(test.in, val) if err != nil { diff --git a/src/pkg/asn1/marshal.go b/src/pkg/asn1/marshal.go index 64cb0f2bb..a3e1145b8 100644 --- a/src/pkg/asn1/marshal.go +++ b/src/pkg/asn1/marshal.go @@ -493,7 +493,7 @@ func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters) // Marshal returns the ASN.1 encoding of val. func Marshal(val interface{}) ([]byte, os.Error) { var out bytes.Buffer - v := reflect.NewValue(val) + v := reflect.ValueOf(val) f := newForkableWriter() err := marshalField(f, v, fieldParameters{}) if err != nil { diff --git a/src/pkg/big/nat.go b/src/pkg/big/nat.go index a04d3b1d9..4848d427b 100755 --- a/src/pkg/big/nat.go +++ b/src/pkg/big/nat.go @@ -2,11 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This file contains operations on unsigned multi-precision integers. -// These are the building blocks for the operations on signed integers -// and rationals. - -// This package implements multi-precision arithmetic (big numbers). +// Package big implements multi-precision arithmetic (big numbers). // The following numeric types are supported: // // - Int signed integers @@ -18,6 +14,10 @@ // package big +// This file contains operations on unsigned multi-precision integers. +// These are the building blocks for the operations on signed integers +// and rationals. + import "rand" // An unsigned integer x of the form diff --git a/src/pkg/bufio/bufio.go b/src/pkg/bufio/bufio.go index 32a25afae..eaae8bb42 100644 --- a/src/pkg/bufio/bufio.go +++ b/src/pkg/bufio/bufio.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements buffered I/O. It wraps an io.Reader or io.Writer +// Package bufio implements buffered I/O. It wraps an io.Reader or io.Writer // object, creating another object (Reader or Writer) that also implements // the interface but provides buffering and some help for textual I/O. package bufio diff --git a/src/pkg/bytes/bytes.go b/src/pkg/bytes/bytes.go index c12a13573..0f9ac9863 100644 --- a/src/pkg/bytes/bytes.go +++ b/src/pkg/bytes/bytes.go @@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The bytes package implements functions for the manipulation of byte slices. -// Analogous to the facilities of the strings package. +// Package bytes implements functions for the manipulation of byte slices. +// It is analogous to the facilities of the strings package. package bytes import ( diff --git a/src/pkg/cmath/abs.go b/src/pkg/cmath/abs.go index 725dc4e98..f3199cad5 100644 --- a/src/pkg/cmath/abs.go +++ b/src/pkg/cmath/abs.go @@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The cmath package provides basic constants -// and mathematical functions for complex numbers. +// Package cmath provides basic constants and mathematical functions for +// complex numbers. package cmath import "math" diff --git a/src/pkg/compress/flate/deflate.go b/src/pkg/compress/flate/deflate.go index 591b35c44..e5b2beaef 100644 --- a/src/pkg/compress/flate/deflate.go +++ b/src/pkg/compress/flate/deflate.go @@ -477,6 +477,33 @@ func NewWriter(w io.Writer, level int) *Writer { return &Writer{pw, &d} } +// NewWriterDict is like NewWriter but initializes the new +// Writer with a preset dictionary. The returned Writer behaves +// as if the dictionary had been written to it without producing +// any compressed output. The compressed data written to w +// can only be decompressed by a Reader initialized with the +// same dictionary. +func NewWriterDict(w io.Writer, level int, dict []byte) *Writer { + dw := &dictWriter{w, false} + zw := NewWriter(dw, level) + zw.Write(dict) + zw.Flush() + dw.enabled = true + return zw +} + +type dictWriter struct { + w io.Writer + enabled bool +} + +func (w *dictWriter) Write(b []byte) (n int, err os.Error) { + if w.enabled { + return w.w.Write(b) + } + return len(b), nil +} + // A Writer takes data written to it and writes the compressed // form of that data to an underlying writer (see NewWriter). type Writer struct { diff --git a/src/pkg/compress/flate/deflate_test.go b/src/pkg/compress/flate/deflate_test.go index ed5884a4b..650a8059a 100644 --- a/src/pkg/compress/flate/deflate_test.go +++ b/src/pkg/compress/flate/deflate_test.go @@ -275,3 +275,49 @@ func TestDeflateInflateString(t *testing.T) { } testToFromWithLevel(t, 1, gold, "2.718281828...") } + +func TestReaderDict(t *testing.T) { + const ( + dict = "hello world" + text = "hello again world" + ) + var b bytes.Buffer + w := NewWriter(&b, 5) + w.Write([]byte(dict)) + w.Flush() + b.Reset() + w.Write([]byte(text)) + w.Close() + + r := NewReaderDict(&b, []byte(dict)) + data, err := ioutil.ReadAll(r) + if err != nil { + t.Fatal(err) + } + if string(data) != "hello again world" { + t.Fatalf("read returned %q want %q", string(data), text) + } +} + +func TestWriterDict(t *testing.T) { + const ( + dict = "hello world" + text = "hello again world" + ) + var b bytes.Buffer + w := NewWriter(&b, 5) + w.Write([]byte(dict)) + w.Flush() + b.Reset() + w.Write([]byte(text)) + w.Close() + + var b1 bytes.Buffer + w = NewWriterDict(&b1, 5, []byte(dict)) + w.Write([]byte(text)) + w.Close() + + if !bytes.Equal(b1.Bytes(), b.Bytes()) { + t.Fatalf("writer wrote %q want %q", b1.Bytes(), b.Bytes()) + } +} diff --git a/src/pkg/compress/flate/inflate.go b/src/pkg/compress/flate/inflate.go index 7dc8cf93b..320b80d06 100644 --- a/src/pkg/compress/flate/inflate.go +++ b/src/pkg/compress/flate/inflate.go @@ -2,9 +2,9 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The flate package implements the DEFLATE compressed data -// format, described in RFC 1951. The gzip and zlib packages -// implement access to DEFLATE-based file formats. +// Package flate implements the DEFLATE compressed data format, described in +// RFC 1951. The gzip and zlib packages implement access to DEFLATE-based file +// formats. package flate import ( @@ -526,6 +526,20 @@ func (f *decompressor) dataBlock() os.Error { return nil } +func (f *decompressor) setDict(dict []byte) { + if len(dict) > len(f.hist) { + // Will only remember the tail. + dict = dict[len(dict)-len(f.hist):] + } + + f.hp = copy(f.hist[:], dict) + if f.hp == len(f.hist) { + f.hp = 0 + f.hfull = true + } + f.hw = f.hp +} + func (f *decompressor) moreBits() os.Error { c, err := f.r.ReadByte() if err != nil { @@ -618,3 +632,16 @@ func NewReader(r io.Reader) io.ReadCloser { go func() { pw.CloseWithError(f.decompress(r, pw)) }() return pr } + +// NewReaderDict is like NewReader but initializes the reader +// with a preset dictionary. The returned Reader behaves as if +// the uncompressed data stream started with the given dictionary, +// which has already been read. NewReaderDict is typically used +// to read data compressed by NewWriterDict. +func NewReaderDict(r io.Reader, dict []byte) io.ReadCloser { + var f decompressor + f.setDict(dict) + pr, pw := io.Pipe() + go func() { pw.CloseWithError(f.decompress(r, pw)) }() + return pr +} diff --git a/src/pkg/compress/gzip/gunzip.go b/src/pkg/compress/gzip/gunzip.go index 3c0b3c5e5..b0ddc81d2 100644 --- a/src/pkg/compress/gzip/gunzip.go +++ b/src/pkg/compress/gzip/gunzip.go @@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The gzip package implements reading and writing of -// gzip format compressed files, as specified in RFC 1952. +// Package gzip implements reading and writing of gzip format compressed files, +// as specified in RFC 1952. package gzip import ( diff --git a/src/pkg/compress/lzw/reader.go b/src/pkg/compress/lzw/reader.go index 8a540cbe6..d418bc856 100644 --- a/src/pkg/compress/lzw/reader.go +++ b/src/pkg/compress/lzw/reader.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The lzw package implements the Lempel-Ziv-Welch compressed data format, +// Package lzw implements the Lempel-Ziv-Welch compressed data format, // described in T. A. Welch, ``A Technique for High-Performance Data // Compression'', Computer, 17(6) (June 1984), pp 8-19. // diff --git a/src/pkg/compress/lzw/reader_test.go b/src/pkg/compress/lzw/reader_test.go index 4b5dfaade..72121a6b5 100644 --- a/src/pkg/compress/lzw/reader_test.go +++ b/src/pkg/compress/lzw/reader_test.go @@ -112,12 +112,6 @@ func TestReader(t *testing.T) { } } -type devNull struct{} - -func (devNull) Write(p []byte) (int, os.Error) { - return len(p), nil -} - func benchmarkDecoder(b *testing.B, n int) { b.StopTimer() b.SetBytes(int64(n)) @@ -134,7 +128,7 @@ func benchmarkDecoder(b *testing.B, n int) { runtime.GC() b.StartTimer() for i := 0; i < b.N; i++ { - io.Copy(devNull{}, NewReader(bytes.NewBuffer(buf1), LSB, 8)) + io.Copy(ioutil.Discard, NewReader(bytes.NewBuffer(buf1), LSB, 8)) } } diff --git a/src/pkg/compress/lzw/writer_test.go b/src/pkg/compress/lzw/writer_test.go index e5815a03d..82464ecd1 100644 --- a/src/pkg/compress/lzw/writer_test.go +++ b/src/pkg/compress/lzw/writer_test.go @@ -113,7 +113,7 @@ func benchmarkEncoder(b *testing.B, n int) { runtime.GC() b.StartTimer() for i := 0; i < b.N; i++ { - w := NewWriter(devNull{}, LSB, 8) + w := NewWriter(ioutil.Discard, LSB, 8) w.Write(buf1) w.Close() } diff --git a/src/pkg/compress/zlib/reader.go b/src/pkg/compress/zlib/reader.go index 721f6ec55..8a3ef1580 100644 --- a/src/pkg/compress/zlib/reader.go +++ b/src/pkg/compress/zlib/reader.go @@ -3,8 +3,8 @@ // license that can be found in the LICENSE file. /* -The zlib package implements reading and writing of zlib -format compressed data, as specified in RFC 1950. +Package zlib implements reading and writing of zlib format compressed data, +as specified in RFC 1950. The implementation provides filters that uncompress during reading and compress during writing. For example, to write compressed data @@ -36,7 +36,7 @@ const zlibDeflate = 8 var ChecksumError os.Error = os.ErrorString("zlib checksum error") var HeaderError os.Error = os.ErrorString("invalid zlib header") -var UnsupportedError os.Error = os.ErrorString("unsupported zlib format") +var DictionaryError os.Error = os.ErrorString("invalid zlib dictionary") type reader struct { r flate.Reader @@ -50,6 +50,12 @@ type reader struct { // The implementation buffers input and may read more data than necessary from r. // It is the caller's responsibility to call Close on the ReadCloser when done. func NewReader(r io.Reader) (io.ReadCloser, os.Error) { + return NewReaderDict(r, nil) +} + +// NewReaderDict is like NewReader but uses a preset dictionary. +// NewReaderDict ignores the dictionary if the compressed data does not refer to it. +func NewReaderDict(r io.Reader, dict []byte) (io.ReadCloser, os.Error) { z := new(reader) if fr, ok := r.(flate.Reader); ok { z.r = fr @@ -65,11 +71,19 @@ func NewReader(r io.Reader) (io.ReadCloser, os.Error) { return nil, HeaderError } if z.scratch[1]&0x20 != 0 { - // BUG(nigeltao): The zlib package does not implement the FDICT flag. - return nil, UnsupportedError + _, err = io.ReadFull(z.r, z.scratch[0:4]) + if err != nil { + return nil, err + } + checksum := uint32(z.scratch[0])<<24 | uint32(z.scratch[1])<<16 | uint32(z.scratch[2])<<8 | uint32(z.scratch[3]) + if checksum != adler32.Checksum(dict) { + return nil, DictionaryError + } + z.decompressor = flate.NewReaderDict(z.r, dict) + } else { + z.decompressor = flate.NewReader(z.r) } z.digest = adler32.New() - z.decompressor = flate.NewReader(z.r) return z, nil } diff --git a/src/pkg/compress/zlib/reader_test.go b/src/pkg/compress/zlib/reader_test.go index eaefc3a36..195db446c 100644 --- a/src/pkg/compress/zlib/reader_test.go +++ b/src/pkg/compress/zlib/reader_test.go @@ -15,6 +15,7 @@ type zlibTest struct { desc string raw string compressed []byte + dict []byte err os.Error } @@ -27,6 +28,7 @@ var zlibTests = []zlibTest{ "", []byte{0x78, 0x9c, 0x03, 0x00, 0x00, 0x00, 0x00, 0x01}, nil, + nil, }, { "goodbye", @@ -37,23 +39,27 @@ var zlibTests = []zlibTest{ 0x01, 0x00, 0x28, 0xa5, 0x05, 0x5e, }, nil, + nil, }, { "bad header", "", []byte{0x78, 0x9f, 0x03, 0x00, 0x00, 0x00, 0x00, 0x01}, + nil, HeaderError, }, { "bad checksum", "", []byte{0x78, 0x9c, 0x03, 0x00, 0x00, 0x00, 0x00, 0xff}, + nil, ChecksumError, }, { "not enough data", "", []byte{0x78, 0x9c, 0x03, 0x00, 0x00, 0x00}, + nil, io.ErrUnexpectedEOF, }, { @@ -64,6 +70,33 @@ var zlibTests = []zlibTest{ 0x78, 0x9c, 0xff, }, nil, + nil, + }, + { + "dictionary", + "Hello, World!\n", + []byte{ + 0x78, 0xbb, 0x1c, 0x32, 0x04, 0x27, 0xf3, 0x00, + 0xb1, 0x75, 0x20, 0x1c, 0x45, 0x2e, 0x00, 0x24, + 0x12, 0x04, 0x74, + }, + []byte{ + 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x20, 0x57, 0x6f, 0x72, 0x6c, 0x64, 0x0a, + }, + nil, + }, + { + "wrong dictionary", + "", + []byte{ + 0x78, 0xbb, 0x1c, 0x32, 0x04, 0x27, 0xf3, 0x00, + 0xb1, 0x75, 0x20, 0x1c, 0x45, 0x2e, 0x00, 0x24, + 0x12, 0x04, 0x74, + }, + []byte{ + 0x48, 0x65, 0x6c, 0x6c, + }, + DictionaryError, }, } @@ -71,7 +104,7 @@ func TestDecompressor(t *testing.T) { b := new(bytes.Buffer) for _, tt := range zlibTests { in := bytes.NewBuffer(tt.compressed) - zlib, err := NewReader(in) + zlib, err := NewReaderDict(in, tt.dict) if err != nil { if err != tt.err { t.Errorf("%s: NewReader: %s", tt.desc, err) diff --git a/src/pkg/compress/zlib/writer.go b/src/pkg/compress/zlib/writer.go index 031586cd2..f1f9b2853 100644 --- a/src/pkg/compress/zlib/writer.go +++ b/src/pkg/compress/zlib/writer.go @@ -21,56 +21,80 @@ const ( DefaultCompression = flate.DefaultCompression ) -type writer struct { +// A Writer takes data written to it and writes the compressed +// form of that data to an underlying writer (see NewWriter). +type Writer struct { w io.Writer - compressor io.WriteCloser + compressor *flate.Writer digest hash.Hash32 err os.Error scratch [4]byte } // NewWriter calls NewWriterLevel with the default compression level. -func NewWriter(w io.Writer) (io.WriteCloser, os.Error) { +func NewWriter(w io.Writer) (*Writer, os.Error) { return NewWriterLevel(w, DefaultCompression) } -// NewWriterLevel creates a new io.WriteCloser that satisfies writes by compressing data written to w. +// NewWriterLevel calls NewWriterDict with no dictionary. +func NewWriterLevel(w io.Writer, level int) (*Writer, os.Error) { + return NewWriterDict(w, level, nil) +} + +// NewWriterDict creates a new io.WriteCloser that satisfies writes by compressing data written to w. // It is the caller's responsibility to call Close on the WriteCloser when done. // level is the compression level, which can be DefaultCompression, NoCompression, // or any integer value between BestSpeed and BestCompression (inclusive). -func NewWriterLevel(w io.Writer, level int) (io.WriteCloser, os.Error) { - z := new(writer) +// dict is the preset dictionary to compress with, or nil to use no dictionary. +func NewWriterDict(w io.Writer, level int, dict []byte) (*Writer, os.Error) { + z := new(Writer) // ZLIB has a two-byte header (as documented in RFC 1950). // The first four bits is the CINFO (compression info), which is 7 for the default deflate window size. // The next four bits is the CM (compression method), which is 8 for deflate. z.scratch[0] = 0x78 // The next two bits is the FLEVEL (compression level). The four values are: // 0=fastest, 1=fast, 2=default, 3=best. - // The next bit, FDICT, is unused, in this implementation. + // The next bit, FDICT, is set if a dictionary is given. // The final five FCHECK bits form a mod-31 checksum. switch level { case 0, 1: - z.scratch[1] = 0x01 + z.scratch[1] = 0 << 6 case 2, 3, 4, 5: - z.scratch[1] = 0x5e + z.scratch[1] = 1 << 6 case 6, -1: - z.scratch[1] = 0x9c + z.scratch[1] = 2 << 6 case 7, 8, 9: - z.scratch[1] = 0xda + z.scratch[1] = 3 << 6 default: return nil, os.NewError("level out of range") } + if dict != nil { + z.scratch[1] |= 1 << 5 + } + z.scratch[1] += uint8(31 - (uint16(z.scratch[0])<<8+uint16(z.scratch[1]))%31) _, err := w.Write(z.scratch[0:2]) if err != nil { return nil, err } + if dict != nil { + // The next four bytes are the Adler-32 checksum of the dictionary. + checksum := adler32.Checksum(dict) + z.scratch[0] = uint8(checksum >> 24) + z.scratch[1] = uint8(checksum >> 16) + z.scratch[2] = uint8(checksum >> 8) + z.scratch[3] = uint8(checksum >> 0) + _, err = w.Write(z.scratch[0:4]) + if err != nil { + return nil, err + } + } z.w = w z.compressor = flate.NewWriter(w, level) z.digest = adler32.New() return z, nil } -func (z *writer) Write(p []byte) (n int, err os.Error) { +func (z *Writer) Write(p []byte) (n int, err os.Error) { if z.err != nil { return 0, z.err } @@ -86,8 +110,17 @@ func (z *writer) Write(p []byte) (n int, err os.Error) { return } +// Flush flushes the underlying compressor. +func (z *Writer) Flush() os.Error { + if z.err != nil { + return z.err + } + z.err = z.compressor.Flush() + return z.err +} + // Calling Close does not close the wrapped io.Writer originally passed to NewWriter. -func (z *writer) Close() os.Error { +func (z *Writer) Close() os.Error { if z.err != nil { return z.err } diff --git a/src/pkg/compress/zlib/writer_test.go b/src/pkg/compress/zlib/writer_test.go index 7eb1cd494..f94f28470 100644 --- a/src/pkg/compress/zlib/writer_test.go +++ b/src/pkg/compress/zlib/writer_test.go @@ -16,13 +16,19 @@ var filenames = []string{ "../testdata/pi.txt", } -// Tests that compressing and then decompressing the given file at the given compression level +// Tests that compressing and then decompressing the given file at the given compression level and dictionary // yields equivalent bytes to the original file. -func testFileLevel(t *testing.T, fn string, level int) { +func testFileLevelDict(t *testing.T, fn string, level int, d string) { + // Read dictionary, if given. + var dict []byte + if d != "" { + dict = []byte(d) + } + // Read the file, as golden output. golden, err := os.Open(fn) if err != nil { - t.Errorf("%s (level=%d): %v", fn, level, err) + t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err) return } defer golden.Close() @@ -30,7 +36,7 @@ func testFileLevel(t *testing.T, fn string, level int) { // Read the file again, and push it through a pipe that compresses at the write end, and decompresses at the read end. raw, err := os.Open(fn) if err != nil { - t.Errorf("%s (level=%d): %v", fn, level, err) + t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err) return } piper, pipew := io.Pipe() @@ -38,9 +44,9 @@ func testFileLevel(t *testing.T, fn string, level int) { go func() { defer raw.Close() defer pipew.Close() - zlibw, err := NewWriterLevel(pipew, level) + zlibw, err := NewWriterDict(pipew, level, dict) if err != nil { - t.Errorf("%s (level=%d): %v", fn, level, err) + t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err) return } defer zlibw.Close() @@ -48,7 +54,7 @@ func testFileLevel(t *testing.T, fn string, level int) { for { n, err0 := raw.Read(b[0:]) if err0 != nil && err0 != os.EOF { - t.Errorf("%s (level=%d): %v", fn, level, err0) + t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err0) return } _, err1 := zlibw.Write(b[0:n]) @@ -57,7 +63,7 @@ func testFileLevel(t *testing.T, fn string, level int) { return } if err1 != nil { - t.Errorf("%s (level=%d): %v", fn, level, err1) + t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err1) return } if err0 == os.EOF { @@ -65,9 +71,9 @@ func testFileLevel(t *testing.T, fn string, level int) { } } }() - zlibr, err := NewReader(piper) + zlibr, err := NewReaderDict(piper, dict) if err != nil { - t.Errorf("%s (level=%d): %v", fn, level, err) + t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err) return } defer zlibr.Close() @@ -76,20 +82,20 @@ func testFileLevel(t *testing.T, fn string, level int) { b0, err0 := ioutil.ReadAll(golden) b1, err1 := ioutil.ReadAll(zlibr) if err0 != nil { - t.Errorf("%s (level=%d): %v", fn, level, err0) + t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err0) return } if err1 != nil { - t.Errorf("%s (level=%d): %v", fn, level, err1) + t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err1) return } if len(b0) != len(b1) { - t.Errorf("%s (level=%d): length mismatch %d versus %d", fn, level, len(b0), len(b1)) + t.Errorf("%s (level=%d, dict=%q): length mismatch %d versus %d", fn, level, d, len(b0), len(b1)) return } for i := 0; i < len(b0); i++ { if b0[i] != b1[i] { - t.Errorf("%s (level=%d): mismatch at %d, 0x%02x versus 0x%02x\n", fn, level, i, b0[i], b1[i]) + t.Errorf("%s (level=%d, dict=%q): mismatch at %d, 0x%02x versus 0x%02x\n", fn, level, d, i, b0[i], b1[i]) return } } @@ -97,10 +103,21 @@ func testFileLevel(t *testing.T, fn string, level int) { func TestWriter(t *testing.T) { for _, fn := range filenames { - testFileLevel(t, fn, DefaultCompression) - testFileLevel(t, fn, NoCompression) + testFileLevelDict(t, fn, DefaultCompression, "") + testFileLevelDict(t, fn, NoCompression, "") + for level := BestSpeed; level <= BestCompression; level++ { + testFileLevelDict(t, fn, level, "") + } + } +} + +func TestWriterDict(t *testing.T) { + const dictionary = "0123456789." + for _, fn := range filenames { + testFileLevelDict(t, fn, DefaultCompression, dictionary) + testFileLevelDict(t, fn, NoCompression, dictionary) for level := BestSpeed; level <= BestCompression; level++ { - testFileLevel(t, fn, level) + testFileLevelDict(t, fn, level, dictionary) } } } diff --git a/src/pkg/container/heap/heap.go b/src/pkg/container/heap/heap.go index 4435a57c4..f2b8a750a 100644 --- a/src/pkg/container/heap/heap.go +++ b/src/pkg/container/heap/heap.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package provides heap operations for any type that implements +// Package heap provides heap operations for any type that implements // heap.Interface. // package heap diff --git a/src/pkg/container/heap/heap_test.go b/src/pkg/container/heap/heap_test.go index 89d444dd5..5eb54374a 100644 --- a/src/pkg/container/heap/heap_test.go +++ b/src/pkg/container/heap/heap_test.go @@ -2,11 +2,12 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package heap +package heap_test import ( "testing" "container/vector" + . "container/heap" ) diff --git a/src/pkg/container/list/list.go b/src/pkg/container/list/list.go index c1ebcddaa..a3fd4b39f 100644 --- a/src/pkg/container/list/list.go +++ b/src/pkg/container/list/list.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The list package implements a doubly linked list. +// Package list implements a doubly linked list. // // To iterate over a list (where l is a *List): // for e := l.Front(); e != nil; e = e.Next() { diff --git a/src/pkg/container/ring/ring.go b/src/pkg/container/ring/ring.go index 5925164e9..cc870ce93 100644 --- a/src/pkg/container/ring/ring.go +++ b/src/pkg/container/ring/ring.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The ring package implements operations on circular lists. +// Package ring implements operations on circular lists. package ring // A Ring is an element of a circular list, or ring. diff --git a/src/pkg/container/vector/defs.go b/src/pkg/container/vector/defs.go index a2febb6de..bfb5481fb 100644 --- a/src/pkg/container/vector/defs.go +++ b/src/pkg/container/vector/defs.go @@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The vector package implements containers for managing sequences -// of elements. Vectors grow and shrink dynamically as necessary. +// Package vector implements containers for managing sequences of elements. +// Vectors grow and shrink dynamically as necessary. package vector diff --git a/src/pkg/crypto/aes/const.go b/src/pkg/crypto/aes/const.go index 97a5b64ec..25acd0d17 100644 --- a/src/pkg/crypto/aes/const.go +++ b/src/pkg/crypto/aes/const.go @@ -2,12 +2,12 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// AES constants - 8720 bytes of initialized data. - -// This package implements AES encryption (formerly Rijndael), -// as defined in U.S. Federal Information Processing Standards Publication 197. +// Package aes implements AES encryption (formerly Rijndael), as defined in +// U.S. Federal Information Processing Standards Publication 197. package aes +// This file contains AES constants - 8720 bytes of initialized data. + // http://www.csrc.nist.gov/publications/fips/fips197/fips-197.pdf // AES is based on the mathematical behavior of binary polynomials diff --git a/src/pkg/crypto/blowfish/cipher.go b/src/pkg/crypto/blowfish/cipher.go index 947f762d8..f3c5175ac 100644 --- a/src/pkg/crypto/blowfish/cipher.go +++ b/src/pkg/crypto/blowfish/cipher.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements Bruce Schneier's Blowfish encryption algorithm. +// Package blowfish implements Bruce Schneier's Blowfish encryption algorithm. package blowfish // The code is a port of Bruce Schneier's C implementation. diff --git a/src/pkg/crypto/cast5/cast5.go b/src/pkg/crypto/cast5/cast5.go index 35f3e64b6..cb62e3132 100644 --- a/src/pkg/crypto/cast5/cast5.go +++ b/src/pkg/crypto/cast5/cast5.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements CAST5, as defined in RFC 2144. CAST5 is a common +// Package cast5 implements CAST5, as defined in RFC 2144. CAST5 is a common // OpenPGP cipher. package cast5 diff --git a/src/pkg/crypto/cipher/cipher.go b/src/pkg/crypto/cipher/cipher.go index 50516b23a..1ffaa8c2c 100644 --- a/src/pkg/crypto/cipher/cipher.go +++ b/src/pkg/crypto/cipher/cipher.go @@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The cipher package implements standard block cipher modes -// that can be wrapped around low-level block cipher implementations. +// Package cipher implements standard block cipher modes that can be wrapped +// around low-level block cipher implementations. // See http://csrc.nist.gov/groups/ST/toolkit/BCM/current_modes.html // and NIST Special Publication 800-38A. package cipher diff --git a/src/pkg/crypto/crypto.go b/src/pkg/crypto/crypto.go index be6b34adf..53672a4da 100644 --- a/src/pkg/crypto/crypto.go +++ b/src/pkg/crypto/crypto.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The crypto package collects common cryptographic constants. +// Package crypto collects common cryptographic constants. package crypto import ( diff --git a/src/pkg/crypto/elliptic/elliptic.go b/src/pkg/crypto/elliptic/elliptic.go index 2296e9607..335c9645d 100644 --- a/src/pkg/crypto/elliptic/elliptic.go +++ b/src/pkg/crypto/elliptic/elliptic.go @@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The elliptic package implements several standard elliptic curves over prime -// fields +// Package elliptic implements several standard elliptic curves over prime +// fields. package elliptic // This package operates, internally, on Jacobian coordinates. For a given diff --git a/src/pkg/crypto/hmac/hmac.go b/src/pkg/crypto/hmac/hmac.go index 298fb2c06..04ec86e9a 100644 --- a/src/pkg/crypto/hmac/hmac.go +++ b/src/pkg/crypto/hmac/hmac.go @@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The hmac package implements the Keyed-Hash Message Authentication Code (HMAC) -// as defined in U.S. Federal Information Processing Standards Publication 198. +// Package hmac implements the Keyed-Hash Message Authentication Code (HMAC) as +// defined in U.S. Federal Information Processing Standards Publication 198. // An HMAC is a cryptographic hash that uses a key to sign a message. // The receiver verifies the hash by recomputing it using the same key. package hmac diff --git a/src/pkg/crypto/md4/md4.go b/src/pkg/crypto/md4/md4.go index ee46544a9..848d9552d 100644 --- a/src/pkg/crypto/md4/md4.go +++ b/src/pkg/crypto/md4/md4.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements the MD4 hash algorithm as defined in RFC 1320. +// Package md4 implements the MD4 hash algorithm as defined in RFC 1320. package md4 import ( diff --git a/src/pkg/crypto/md5/md5.go b/src/pkg/crypto/md5/md5.go index 8f93fc4b3..378faa6ec 100644 --- a/src/pkg/crypto/md5/md5.go +++ b/src/pkg/crypto/md5/md5.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements the MD5 hash algorithm as defined in RFC 1321. +// Package md5 implements the MD5 hash algorithm as defined in RFC 1321. package md5 import ( diff --git a/src/pkg/crypto/ocsp/ocsp.go b/src/pkg/crypto/ocsp/ocsp.go index f42d80888..acd75b8b0 100644 --- a/src/pkg/crypto/ocsp/ocsp.go +++ b/src/pkg/crypto/ocsp/ocsp.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package parses OCSP responses as specified in RFC 2560. OCSP responses +// Package ocsp parses OCSP responses as specified in RFC 2560. OCSP responses // are signed messages attesting to the validity of a certificate for a small // period of time. This is used to manage revocation for X.509 certificates. package ocsp diff --git a/src/pkg/crypto/openpgp/armor/armor.go b/src/pkg/crypto/openpgp/armor/armor.go index d695a8c33..8da612c50 100644 --- a/src/pkg/crypto/openpgp/armor/armor.go +++ b/src/pkg/crypto/openpgp/armor/armor.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements OpenPGP ASCII Armor, see RFC 4880. OpenPGP Armor is +// Package armor implements OpenPGP ASCII Armor, see RFC 4880. OpenPGP Armor is // very similar to PEM except that it has an additional CRC checksum. package armor diff --git a/src/pkg/crypto/openpgp/error/error.go b/src/pkg/crypto/openpgp/error/error.go index 053d15967..3759ce161 100644 --- a/src/pkg/crypto/openpgp/error/error.go +++ b/src/pkg/crypto/openpgp/error/error.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package contains common error types for the OpenPGP packages. +// Package error contains common error types for the OpenPGP packages. package error import ( diff --git a/src/pkg/crypto/openpgp/keys.go b/src/pkg/crypto/openpgp/keys.go index ecaa86f28..6c03f8828 100644 --- a/src/pkg/crypto/openpgp/keys.go +++ b/src/pkg/crypto/openpgp/keys.go @@ -5,6 +5,7 @@ package openpgp import ( + "crypto/openpgp/armor" "crypto/openpgp/error" "crypto/openpgp/packet" "io" @@ -13,6 +14,8 @@ import ( // PublicKeyType is the armor type for a PGP public key. var PublicKeyType = "PGP PUBLIC KEY BLOCK" +// PrivateKeyType is the armor type for a PGP private key. +var PrivateKeyType = "PGP PRIVATE KEY BLOCK" // An Entity represents the components of an OpenPGP key: a primary public key // (which must be a signing key), one or more identities claimed by that key, @@ -101,37 +104,50 @@ func (el EntityList) DecryptionKeys() (keys []Key) { // ReadArmoredKeyRing reads one or more public/private keys from an armor keyring file. func ReadArmoredKeyRing(r io.Reader) (EntityList, os.Error) { - body, err := readArmored(r, PublicKeyType) + block, err := armor.Decode(r) + if err == os.EOF { + return nil, error.InvalidArgumentError("no armored data found") + } if err != nil { return nil, err } + if block.Type != PublicKeyType && block.Type != PrivateKeyType { + return nil, error.InvalidArgumentError("expected public or private key block, got: " + block.Type) + } - return ReadKeyRing(body) + return ReadKeyRing(block.Body) } -// ReadKeyRing reads one or more public/private keys, ignoring unsupported keys. +// ReadKeyRing reads one or more public/private keys. Unsupported keys are +// ignored as long as at least a single valid key is found. func ReadKeyRing(r io.Reader) (el EntityList, err os.Error) { packets := packet.NewReader(r) + var lastUnsupportedError os.Error for { var e *Entity e, err = readEntity(packets) if err != nil { if _, ok := err.(error.UnsupportedError); ok { + lastUnsupportedError = err err = readToNextPublicKey(packets) } if err == os.EOF { err = nil - return + break } if err != nil { el = nil - return + break } } else { el = append(el, e) } } + + if len(el) == 0 && err == nil { + err = lastUnsupportedError + } return } @@ -197,25 +213,28 @@ EachPacket: current.Name = pkt.Id current.UserId = pkt e.Identities[pkt.Id] = current - p, err = packets.Next() - if err == os.EOF { - err = io.ErrUnexpectedEOF - } - if err != nil { - if _, ok := err.(error.UnsupportedError); ok { + + for { + p, err = packets.Next() + if err == os.EOF { + return nil, io.ErrUnexpectedEOF + } else if err != nil { return nil, err } - return nil, error.StructuralError("identity self-signature invalid: " + err.String()) - } - current.SelfSignature, ok = p.(*packet.Signature) - if !ok { - return nil, error.StructuralError("user ID packet not followed by self signature") - } - if current.SelfSignature.SigType != packet.SigTypePositiveCert { - return nil, error.StructuralError("user ID self-signature with wrong type") - } - if err = e.PrimaryKey.VerifyUserIdSignature(pkt.Id, current.SelfSignature); err != nil { - return nil, error.StructuralError("user ID self-signature invalid: " + err.String()) + + sig, ok := p.(*packet.Signature) + if !ok { + return nil, error.StructuralError("user ID packet not followed by self-signature") + } + + if sig.SigType == packet.SigTypePositiveCert && sig.IssuerKeyId != nil && *sig.IssuerKeyId == e.PrimaryKey.KeyId { + if err = e.PrimaryKey.VerifyUserIdSignature(pkt.Id, sig); err != nil { + return nil, error.StructuralError("user ID self-signature invalid: " + err.String()) + } + current.SelfSignature = sig + break + } + current.Signatures = append(current.Signatures, sig) } case *packet.Signature: if current == nil { diff --git a/src/pkg/crypto/openpgp/packet/packet.go b/src/pkg/crypto/openpgp/packet/packet.go index 57ff3afbf..c0ec44dd8 100644 --- a/src/pkg/crypto/openpgp/packet/packet.go +++ b/src/pkg/crypto/openpgp/packet/packet.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements parsing and serialisation of OpenPGP packets, as +// Package packet implements parsing and serialisation of OpenPGP packets, as // specified in RFC 4880. package packet diff --git a/src/pkg/crypto/openpgp/packet/private_key.go b/src/pkg/crypto/openpgp/packet/private_key.go index 694482390..fde2a9933 100644 --- a/src/pkg/crypto/openpgp/packet/private_key.go +++ b/src/pkg/crypto/openpgp/packet/private_key.go @@ -164,8 +164,10 @@ func (pk *PrivateKey) parseRSAPrivateKey(data []byte) (err os.Error) { } rsaPriv.D = new(big.Int).SetBytes(d) - rsaPriv.P = new(big.Int).SetBytes(p) - rsaPriv.Q = new(big.Int).SetBytes(q) + rsaPriv.Primes = make([]*big.Int, 2) + rsaPriv.Primes[0] = new(big.Int).SetBytes(p) + rsaPriv.Primes[1] = new(big.Int).SetBytes(q) + rsaPriv.Precompute() pk.PrivateKey = rsaPriv pk.Encrypted = false pk.encryptedData = nil diff --git a/src/pkg/crypto/openpgp/packet/public_key.go b/src/pkg/crypto/openpgp/packet/public_key.go index ebef481fb..cd4a9aebb 100644 --- a/src/pkg/crypto/openpgp/packet/public_key.go +++ b/src/pkg/crypto/openpgp/packet/public_key.go @@ -15,6 +15,7 @@ import ( "hash" "io" "os" + "strconv" ) // PublicKey represents an OpenPGP public key. See RFC 4880, section 5.5.2. @@ -47,7 +48,7 @@ func (pk *PublicKey) parse(r io.Reader) (err os.Error) { case PubKeyAlgoDSA: err = pk.parseDSA(r) default: - err = error.UnsupportedError("public key type") + err = error.UnsupportedError("public key type: " + strconv.Itoa(int(pk.PubKeyAlgo))) } if err != nil { return diff --git a/src/pkg/crypto/openpgp/read.go b/src/pkg/crypto/openpgp/read.go index ac6998f0d..4f84dff82 100644 --- a/src/pkg/crypto/openpgp/read.go +++ b/src/pkg/crypto/openpgp/read.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This openpgp package implements high level operations on OpenPGP messages. +// Package openpgp implements high level operations on OpenPGP messages. package openpgp import ( diff --git a/src/pkg/crypto/openpgp/read_test.go b/src/pkg/crypto/openpgp/read_test.go index 6218d9990..423c85b0f 100644 --- a/src/pkg/crypto/openpgp/read_test.go +++ b/src/pkg/crypto/openpgp/read_test.go @@ -230,6 +230,23 @@ func TestDetachedSignatureDSA(t *testing.T) { testDetachedSignature(t, kring, readerFromHex(detachedSignatureDSAHex), signedInput, "binary", testKey3KeyId) } +func TestReadingArmoredPrivateKey(t *testing.T) { + el, err := ReadArmoredKeyRing(bytes.NewBufferString(armoredPrivateKeyBlock)) + if err != nil { + t.Error(err) + } + if len(el) != 1 { + t.Errorf("got %d entities, wanted 1\n", len(el)) + } +} + +func TestNoArmoredData(t *testing.T) { + _, err := ReadArmoredKeyRing(bytes.NewBufferString("foo")) + if _, ok := err.(error.InvalidArgumentError); !ok { + t.Errorf("error was not an InvalidArgumentError: %s", err) + } +} + const testKey1KeyId = 0xA34D7E18C20C31BB const testKey3KeyId = 0x338934250CCC0360 @@ -259,3 +276,37 @@ const symmetricallyEncryptedCompressedHex = "8c0d04030302eb4a03808145d0d260c92f7 const dsaTestKeyHex = "9901a2044d6c49de110400cb5ce438cf9250907ac2ba5bf6547931270b89f7c4b53d9d09f4d0213a5ef2ec1f26806d3d259960f872a4a102ef1581ea3f6d6882d15134f21ef6a84de933cc34c47cc9106efe3bd84c6aec12e78523661e29bc1a61f0aab17fa58a627fd5fd33f5149153fbe8cd70edf3d963bc287ef875270ff14b5bfdd1bca4483793923b00a0fe46d76cb6e4cbdc568435cd5480af3266d610d303fe33ae8273f30a96d4d34f42fa28ce1112d425b2e3bf7ea553d526e2db6b9255e9dc7419045ce817214d1a0056dbc8d5289956a4b1b69f20f1105124096e6a438f41f2e2495923b0f34b70642607d45559595c7fe94d7fa85fc41bf7d68c1fd509ebeaa5f315f6059a446b9369c277597e4f474a9591535354c7e7f4fd98a08aa60400b130c24ff20bdfbf683313f5daebf1c9b34b3bdadfc77f2ddd72ee1fb17e56c473664bc21d66467655dd74b9005e3a2bacce446f1920cd7017231ae447b67036c9b431b8179deacd5120262d894c26bc015bffe3d827ba7087ad9b700d2ca1f6d16cc1786581e5dd065f293c31209300f9b0afcc3f7c08dd26d0a22d87580b4db41054657374204b65792033202844534129886204131102002205024d6c49de021b03060b090807030206150802090a0b0416020301021e01021780000a0910338934250ccc03607e0400a0bdb9193e8a6b96fc2dfc108ae848914b504481f100a09c4dc148cb693293a67af24dd40d2b13a9e36794" const dsaTestKeyPrivateHex = "9501bb044d6c49de110400cb5ce438cf9250907ac2ba5bf6547931270b89f7c4b53d9d09f4d0213a5ef2ec1f26806d3d259960f872a4a102ef1581ea3f6d6882d15134f21ef6a84de933cc34c47cc9106efe3bd84c6aec12e78523661e29bc1a61f0aab17fa58a627fd5fd33f5149153fbe8cd70edf3d963bc287ef875270ff14b5bfdd1bca4483793923b00a0fe46d76cb6e4cbdc568435cd5480af3266d610d303fe33ae8273f30a96d4d34f42fa28ce1112d425b2e3bf7ea553d526e2db6b9255e9dc7419045ce817214d1a0056dbc8d5289956a4b1b69f20f1105124096e6a438f41f2e2495923b0f34b70642607d45559595c7fe94d7fa85fc41bf7d68c1fd509ebeaa5f315f6059a446b9369c277597e4f474a9591535354c7e7f4fd98a08aa60400b130c24ff20bdfbf683313f5daebf1c9b34b3bdadfc77f2ddd72ee1fb17e56c473664bc21d66467655dd74b9005e3a2bacce446f1920cd7017231ae447b67036c9b431b8179deacd5120262d894c26bc015bffe3d827ba7087ad9b700d2ca1f6d16cc1786581e5dd065f293c31209300f9b0afcc3f7c08dd26d0a22d87580b4d00009f592e0619d823953577d4503061706843317e4fee083db41054657374204b65792033202844534129886204131102002205024d6c49de021b03060b090807030206150802090a0b0416020301021e01021780000a0910338934250ccc03607e0400a0bdb9193e8a6b96fc2dfc108ae848914b504481f100a09c4dc148cb693293a67af24dd40d2b13a9e36794" + +const armoredPrivateKeyBlock = `-----BEGIN PGP PRIVATE KEY BLOCK----- +Version: GnuPG v1.4.10 (GNU/Linux) + +lQHYBE2rFNoBBADFwqWQIW/DSqcB4yCQqnAFTJ27qS5AnB46ccAdw3u4Greeu3Bp +idpoHdjULy7zSKlwR1EA873dO/k/e11Ml3dlAFUinWeejWaK2ugFP6JjiieSsrKn +vWNicdCS4HTWn0X4sjl0ZiAygw6GNhqEQ3cpLeL0g8E9hnYzJKQ0LWJa0QARAQAB +AAP/TB81EIo2VYNmTq0pK1ZXwUpxCrvAAIG3hwKjEzHcbQznsjNvPUihZ+NZQ6+X +0HCfPAdPkGDCLCb6NavcSW+iNnLTrdDnSI6+3BbIONqWWdRDYJhqZCkqmG6zqSfL +IdkJgCw94taUg5BWP/AAeQrhzjChvpMQTVKQL5mnuZbUCeMCAN5qrYMP2S9iKdnk +VANIFj7656ARKt/nf4CBzxcpHTyB8+d2CtPDKCmlJP6vL8t58Jmih+kHJMvC0dzn +gr5f5+sCAOOe5gt9e0am7AvQWhdbHVfJU0TQJx+m2OiCJAqGTB1nvtBLHdJnfdC9 +TnXXQ6ZXibqLyBies/xeY2sCKL5qtTMCAKnX9+9d/5yQxRyrQUHt1NYhaXZnJbHx +q4ytu0eWz+5i68IYUSK69jJ1NWPM0T6SkqpB3KCAIv68VFm9PxqG1KmhSrQIVGVz +dCBLZXmIuAQTAQIAIgUCTasU2gIbAwYLCQgHAwIGFQgCCQoLBBYCAwECHgECF4AA +CgkQO9o98PRieSoLhgQAkLEZex02Qt7vGhZzMwuN0R22w3VwyYyjBx+fM3JFETy1 +ut4xcLJoJfIaF5ZS38UplgakHG0FQ+b49i8dMij0aZmDqGxrew1m4kBfjXw9B/v+ +eIqpODryb6cOSwyQFH0lQkXC040pjq9YqDsO5w0WYNXYKDnzRV0p4H1pweo2VDid +AdgETasU2gEEAN46UPeWRqKHvA99arOxee38fBt2CI08iiWyI8T3J6ivtFGixSqV +bRcPxYO/qLpVe5l84Nb3X71GfVXlc9hyv7CD6tcowL59hg1E/DC5ydI8K8iEpUmK +/UnHdIY5h8/kqgGxkY/T/hgp5fRQgW1ZoZxLajVlMRZ8W4tFtT0DeA+JABEBAAEA +A/0bE1jaaZKj6ndqcw86jd+QtD1SF+Cf21CWRNeLKnUds4FRRvclzTyUMuWPkUeX +TaNNsUOFqBsf6QQ2oHUBBK4VCHffHCW4ZEX2cd6umz7mpHW6XzN4DECEzOVksXtc +lUC1j4UB91DC/RNQqwX1IV2QLSwssVotPMPqhOi0ZLNY7wIA3n7DWKInxYZZ4K+6 +rQ+POsz6brEoRHwr8x6XlHenq1Oki855pSa1yXIARoTrSJkBtn5oI+f8AzrnN0BN +oyeQAwIA/7E++3HDi5aweWrViiul9cd3rcsS0dEnksPhvS0ozCJiHsq/6GFmy7J8 +QSHZPteedBnZyNp5jR+H7cIfVN3KgwH/Skq4PsuPhDq5TKK6i8Pc1WW8MA6DXTdU +nLkX7RGmMwjC0DBf7KWAlPjFaONAX3a8ndnz//fy1q7u2l9AZwrj1qa1iJ8EGAEC +AAkFAk2rFNoCGwwACgkQO9o98PRieSo2/QP/WTzr4ioINVsvN1akKuekmEMI3LAp +BfHwatufxxP1U+3Si/6YIk7kuPB9Hs+pRqCXzbvPRrI8NHZBmc8qIGthishdCYad +AHcVnXjtxrULkQFGbGvhKURLvS9WnzD/m1K2zzwxzkPTzT9/Yf06O6Mal5AdugPL +VrM0m72/jnpKo04= +=zNCn +-----END PGP PRIVATE KEY BLOCK-----` diff --git a/src/pkg/crypto/openpgp/s2k/s2k.go b/src/pkg/crypto/openpgp/s2k/s2k.go index 873b33dc0..93b7582fa 100644 --- a/src/pkg/crypto/openpgp/s2k/s2k.go +++ b/src/pkg/crypto/openpgp/s2k/s2k.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements the various OpenPGP string-to-key transforms as +// Package s2k implements the various OpenPGP string-to-key transforms as // specified in RFC 4800 section 3.7.1. package s2k diff --git a/src/pkg/crypto/rc4/rc4.go b/src/pkg/crypto/rc4/rc4.go index 65fd195f3..7ee471093 100644 --- a/src/pkg/crypto/rc4/rc4.go +++ b/src/pkg/crypto/rc4/rc4.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements RC4 encryption, as defined in Bruce Schneier's +// Package rc4 implements RC4 encryption, as defined in Bruce Schneier's // Applied Cryptography. package rc4 diff --git a/src/pkg/crypto/ripemd160/ripemd160.go b/src/pkg/crypto/ripemd160/ripemd160.go index 6e88521c3..5aaca59a3 100644 --- a/src/pkg/crypto/ripemd160/ripemd160.go +++ b/src/pkg/crypto/ripemd160/ripemd160.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements the RIPEMD-160 hash algorithm. +// Package ripemd160 implements the RIPEMD-160 hash algorithm. package ripemd160 // RIPEMD-160 is designed by by Hans Dobbertin, Antoon Bosselaers, and Bart diff --git a/src/pkg/crypto/rsa/pkcs1v15_test.go b/src/pkg/crypto/rsa/pkcs1v15_test.go index 30a4824a6..d69bacfd6 100644 --- a/src/pkg/crypto/rsa/pkcs1v15_test.go +++ b/src/pkg/crypto/rsa/pkcs1v15_test.go @@ -197,12 +197,6 @@ func TestVerifyPKCS1v15(t *testing.T) { } } -func bigFromString(s string) *big.Int { - ret := new(big.Int) - ret.SetString(s, 10) - return ret -} - // In order to generate new test vectors you'll need the PEM form of this key: // -----BEGIN RSA PRIVATE KEY----- // MIIBOgIBAAJBALKZD0nEffqM1ACuak0bijtqE2QrI/KLADv7l3kK3ppMyCuLKoF0 @@ -216,10 +210,12 @@ func bigFromString(s string) *big.Int { var rsaPrivateKey = &PrivateKey{ PublicKey: PublicKey{ - N: bigFromString("9353930466774385905609975137998169297361893554149986716853295022578535724979677252958524466350471210367835187480748268864277464700638583474144061408845077"), + N: fromBase10("9353930466774385905609975137998169297361893554149986716853295022578535724979677252958524466350471210367835187480748268864277464700638583474144061408845077"), E: 65537, }, - D: bigFromString("7266398431328116344057699379749222532279343923819063639497049039389899328538543087657733766554155839834519529439851673014800261285757759040931985506583861"), - P: bigFromString("98920366548084643601728869055592650835572950932266967461790948584315647051443"), - Q: bigFromString("94560208308847015747498523884063394671606671904944666360068158221458669711639"), + D: fromBase10("7266398431328116344057699379749222532279343923819063639497049039389899328538543087657733766554155839834519529439851673014800261285757759040931985506583861"), + Primes: []*big.Int{ + fromBase10("98920366548084643601728869055592650835572950932266967461790948584315647051443"), + fromBase10("94560208308847015747498523884063394671606671904944666360068158221458669711639"), + }, } diff --git a/src/pkg/crypto/rsa/rsa.go b/src/pkg/crypto/rsa/rsa.go index b3b212c20..e1813dbf9 100644 --- a/src/pkg/crypto/rsa/rsa.go +++ b/src/pkg/crypto/rsa/rsa.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements RSA encryption as specified in PKCS#1. +// Package rsa implements RSA encryption as specified in PKCS#1. package rsa // TODO(agl): Add support for PSS padding. @@ -13,7 +13,6 @@ import ( "hash" "io" "os" - "sync" ) var bigZero = big.NewInt(0) @@ -90,50 +89,60 @@ type PublicKey struct { // A PrivateKey represents an RSA key type PrivateKey struct { - PublicKey // public part. - D *big.Int // private exponent - P, Q, R *big.Int // prime factors of N (R may be nil) - - rwMutex sync.RWMutex // protects the following - dP, dQ, dR *big.Int // D mod (P-1) (or mod Q-1 etc) - qInv *big.Int // q^-1 mod p - pq *big.Int // P*Q - tr *big.Int // pq·tr ≡ 1 mod r + PublicKey // public part. + D *big.Int // private exponent + Primes []*big.Int // prime factors of N, has >= 2 elements. + + // Precomputed contains precomputed values that speed up private + // operations, if availible. + Precomputed PrecomputedValues +} + +type PrecomputedValues struct { + Dp, Dq *big.Int // D mod (P-1) (or mod Q-1) + Qinv *big.Int // Q^-1 mod Q + + // CRTValues is used for the 3rd and subsequent primes. Due to a + // historical accident, the CRT for the first two primes is handled + // differently in PKCS#1 and interoperability is sufficiently + // important that we mirror this. + CRTValues []CRTValue +} + +// CRTValue contains the precomputed chinese remainder theorem values. +type CRTValue struct { + Exp *big.Int // D mod (prime-1). + Coeff *big.Int // R·Coeff ≡ 1 mod Prime. + R *big.Int // product of primes prior to this (inc p and q). } // Validate performs basic sanity checks on the key. // It returns nil if the key is valid, or else an os.Error describing a problem. func (priv *PrivateKey) Validate() os.Error { - // Check that p, q and, maybe, r are prime. Note that this is just a - // sanity check. Since the random witnesses chosen by ProbablyPrime are - // deterministic, given the candidate number, it's easy for an attack - // to generate composites that pass this test. - if !big.ProbablyPrime(priv.P, 20) { - return os.ErrorString("P is composite") - } - if !big.ProbablyPrime(priv.Q, 20) { - return os.ErrorString("Q is composite") - } - if priv.R != nil && !big.ProbablyPrime(priv.R, 20) { - return os.ErrorString("R is composite") + // Check that the prime factors are actually prime. Note that this is + // just a sanity check. Since the random witnesses chosen by + // ProbablyPrime are deterministic, given the candidate number, it's + // easy for an attack to generate composites that pass this test. + for _, prime := range priv.Primes { + if !big.ProbablyPrime(prime, 20) { + return os.ErrorString("Prime factor is composite") + } } - // Check that p*q*r == n. - modulus := new(big.Int).Mul(priv.P, priv.Q) - if priv.R != nil { - modulus.Mul(modulus, priv.R) + // Check that Πprimes == n. + modulus := new(big.Int).Set(bigOne) + for _, prime := range priv.Primes { + modulus.Mul(modulus, prime) } if modulus.Cmp(priv.N) != 0 { return os.ErrorString("invalid modulus") } - // Check that e and totient(p, q, r) are coprime. - pminus1 := new(big.Int).Sub(priv.P, bigOne) - qminus1 := new(big.Int).Sub(priv.Q, bigOne) - totient := new(big.Int).Mul(pminus1, qminus1) - if priv.R != nil { - rminus1 := new(big.Int).Sub(priv.R, bigOne) - totient.Mul(totient, rminus1) + // Check that e and totient(Πprimes) are coprime. + totient := new(big.Int).Set(bigOne) + for _, prime := range priv.Primes { + pminus1 := new(big.Int).Sub(prime, bigOne) + totient.Mul(totient, pminus1) } e := big.NewInt(int64(priv.E)) gcd := new(big.Int) @@ -143,7 +152,7 @@ func (priv *PrivateKey) Validate() os.Error { if gcd.Cmp(bigOne) != 0 { return os.ErrorString("invalid public exponent E") } - // Check that de ≡ 1 (mod totient(p, q, r)) + // Check that de ≡ 1 (mod totient(Πprimes)) de := new(big.Int).Mul(priv.D, e) de.Mod(de, totient) if de.Cmp(bigOne) != 0 { @@ -154,6 +163,20 @@ func (priv *PrivateKey) Validate() os.Error { // GenerateKey generates an RSA keypair of the given bit size. func GenerateKey(rand io.Reader, bits int) (priv *PrivateKey, err os.Error) { + return GenerateMultiPrimeKey(rand, 2, bits) +} + +// GenerateMultiPrimeKey generates a multi-prime RSA keypair of the given bit +// size, as suggested in [1]. Although the public keys are compatible +// (actually, indistinguishable) from the 2-prime case, the private keys are +// not. Thus it may not be possible to export multi-prime private keys in +// certain formats or to subsequently import them into other code. +// +// Table 1 in [2] suggests maximum numbers of primes for a given size. +// +// [1] US patent 4405829 (1972, expired) +// [2] http://www.cacr.math.uwaterloo.ca/techreports/2006/cacr2006-16.pdf +func GenerateMultiPrimeKey(rand io.Reader, nprimes int, bits int) (priv *PrivateKey, err os.Error) { priv = new(PrivateKey) // Smaller public exponents lead to faster public key // operations. Since the exponent must be coprime to @@ -165,100 +188,41 @@ func GenerateKey(rand io.Reader, bits int) (priv *PrivateKey, err os.Error) { // [1] http://marc.info/?l=cryptography&m=115694833312008&w=2 priv.E = 3 - pminus1 := new(big.Int) - qminus1 := new(big.Int) - totient := new(big.Int) - - for { - p, err := randomPrime(rand, bits/2) - if err != nil { - return nil, err - } - - q, err := randomPrime(rand, bits/2) - if err != nil { - return nil, err - } - - if p.Cmp(q) == 0 { - continue - } - - n := new(big.Int).Mul(p, q) - pminus1.Sub(p, bigOne) - qminus1.Sub(q, bigOne) - totient.Mul(pminus1, qminus1) - - g := new(big.Int) - priv.D = new(big.Int) - y := new(big.Int) - e := big.NewInt(int64(priv.E)) - big.GcdInt(g, priv.D, y, e, totient) - - if g.Cmp(bigOne) == 0 { - priv.D.Add(priv.D, totient) - priv.P = p - priv.Q = q - priv.N = n - - break - } + if nprimes < 2 { + return nil, os.ErrorString("rsa.GenerateMultiPrimeKey: nprimes must be >= 2") } - return -} - -// Generate3PrimeKey generates a 3-prime RSA keypair of the given bit size, as -// suggested in [1]. Although the public keys are compatible (actually, -// indistinguishable) from the 2-prime case, the private keys are not. Thus it -// may not be possible to export 3-prime private keys in certain formats or to -// subsequently import them into other code. -// -// Table 1 in [2] suggests that size should be >= 1024 when using 3 primes. -// -// [1] US patent 4405829 (1972, expired) -// [2] http://www.cacr.math.uwaterloo.ca/techreports/2006/cacr2006-16.pdf -func Generate3PrimeKey(rand io.Reader, bits int) (priv *PrivateKey, err os.Error) { - priv = new(PrivateKey) - priv.E = 3 - - pminus1 := new(big.Int) - qminus1 := new(big.Int) - rminus1 := new(big.Int) - totient := new(big.Int) + primes := make([]*big.Int, nprimes) +NextSetOfPrimes: for { - p, err := randomPrime(rand, bits/3) - if err != nil { - return nil, err - } - - todo := bits - p.BitLen() - q, err := randomPrime(rand, todo/2) - if err != nil { - return nil, err + todo := bits + for i := 0; i < nprimes; i++ { + primes[i], err = randomPrime(rand, todo/(nprimes-i)) + if err != nil { + return nil, err + } + todo -= primes[i].BitLen() } - todo -= q.BitLen() - r, err := randomPrime(rand, todo) - if err != nil { - return nil, err + // Make sure that primes is pairwise unequal. + for i, prime := range primes { + for j := 0; j < i; j++ { + if prime.Cmp(primes[j]) == 0 { + continue NextSetOfPrimes + } + } } - if p.Cmp(q) == 0 || - q.Cmp(r) == 0 || - r.Cmp(p) == 0 { - continue + n := new(big.Int).Set(bigOne) + totient := new(big.Int).Set(bigOne) + pminus1 := new(big.Int) + for _, prime := range primes { + n.Mul(n, prime) + pminus1.Sub(prime, bigOne) + totient.Mul(totient, pminus1) } - n := new(big.Int).Mul(p, q) - n.Mul(n, r) - pminus1.Sub(p, bigOne) - qminus1.Sub(q, bigOne) - rminus1.Sub(r, bigOne) - totient.Mul(pminus1, qminus1) - totient.Mul(totient, rminus1) - g := new(big.Int) priv.D = new(big.Int) y := new(big.Int) @@ -267,15 +231,14 @@ func Generate3PrimeKey(rand io.Reader, bits int) (priv *PrivateKey, err os.Error if g.Cmp(bigOne) == 0 { priv.D.Add(priv.D, totient) - priv.P = p - priv.Q = q - priv.R = r + priv.Primes = primes priv.N = n break } } + priv.Precompute() return } @@ -409,23 +372,34 @@ func modInverse(a, n *big.Int) (ia *big.Int, ok bool) { return x, true } -// precompute performs some calculations that speed up private key operations +// Precompute performs some calculations that speed up private key operations // in the future. -func (priv *PrivateKey) precompute() { - priv.dP = new(big.Int).Sub(priv.P, bigOne) - priv.dP.Mod(priv.D, priv.dP) +func (priv *PrivateKey) Precompute() { + if priv.Precomputed.Dp != nil { + return + } - priv.dQ = new(big.Int).Sub(priv.Q, bigOne) - priv.dQ.Mod(priv.D, priv.dQ) + priv.Precomputed.Dp = new(big.Int).Sub(priv.Primes[0], bigOne) + priv.Precomputed.Dp.Mod(priv.D, priv.Precomputed.Dp) - priv.qInv = new(big.Int).ModInverse(priv.Q, priv.P) + priv.Precomputed.Dq = new(big.Int).Sub(priv.Primes[1], bigOne) + priv.Precomputed.Dq.Mod(priv.D, priv.Precomputed.Dq) - if priv.R != nil { - priv.dR = new(big.Int).Sub(priv.R, bigOne) - priv.dR.Mod(priv.D, priv.dR) + priv.Precomputed.Qinv = new(big.Int).ModInverse(priv.Primes[1], priv.Primes[0]) - priv.pq = new(big.Int).Mul(priv.P, priv.Q) - priv.tr = new(big.Int).ModInverse(priv.pq, priv.R) + r := new(big.Int).Mul(priv.Primes[0], priv.Primes[1]) + priv.Precomputed.CRTValues = make([]CRTValue, len(priv.Primes)-2) + for i := 2; i < len(priv.Primes); i++ { + prime := priv.Primes[i] + values := &priv.Precomputed.CRTValues[i-2] + + values.Exp = new(big.Int).Sub(prime, bigOne) + values.Exp.Mod(priv.D, values.Exp) + + values.R = new(big.Int).Set(r) + values.Coeff = new(big.Int).ModInverse(r, prime) + + r.Mul(r, prime) } } @@ -463,53 +437,41 @@ func decrypt(rand io.Reader, priv *PrivateKey, c *big.Int) (m *big.Int, err os.E } bigE := big.NewInt(int64(priv.E)) rpowe := new(big.Int).Exp(r, bigE, priv.N) - c.Mul(c, rpowe) - c.Mod(c, priv.N) - } - - priv.rwMutex.RLock() - - if priv.dP == nil && priv.P != nil { - priv.rwMutex.RUnlock() - priv.rwMutex.Lock() - if priv.dP == nil && priv.P != nil { - priv.precompute() - } - priv.rwMutex.Unlock() - priv.rwMutex.RLock() + cCopy := new(big.Int).Set(c) + cCopy.Mul(cCopy, rpowe) + cCopy.Mod(cCopy, priv.N) + c = cCopy } - if priv.dP == nil { + if priv.Precomputed.Dp == nil { m = new(big.Int).Exp(c, priv.D, priv.N) } else { // We have the precalculated values needed for the CRT. - m = new(big.Int).Exp(c, priv.dP, priv.P) - m2 := new(big.Int).Exp(c, priv.dQ, priv.Q) + m = new(big.Int).Exp(c, priv.Precomputed.Dp, priv.Primes[0]) + m2 := new(big.Int).Exp(c, priv.Precomputed.Dq, priv.Primes[1]) m.Sub(m, m2) if m.Sign() < 0 { - m.Add(m, priv.P) + m.Add(m, priv.Primes[0]) } - m.Mul(m, priv.qInv) - m.Mod(m, priv.P) - m.Mul(m, priv.Q) + m.Mul(m, priv.Precomputed.Qinv) + m.Mod(m, priv.Primes[0]) + m.Mul(m, priv.Primes[1]) m.Add(m, m2) - if priv.dR != nil { - // 3-prime CRT. - m2.Exp(c, priv.dR, priv.R) + for i, values := range priv.Precomputed.CRTValues { + prime := priv.Primes[2+i] + m2.Exp(c, values.Exp, prime) m2.Sub(m2, m) - m2.Mul(m2, priv.tr) - m2.Mod(m2, priv.R) + m2.Mul(m2, values.Coeff) + m2.Mod(m2, prime) if m2.Sign() < 0 { - m2.Add(m2, priv.R) + m2.Add(m2, prime) } - m2.Mul(m2, priv.pq) + m2.Mul(m2, values.R) m.Add(m, m2) } } - priv.rwMutex.RUnlock() - if ir != nil { // Unblind. m.Mul(m, ir) diff --git a/src/pkg/crypto/rsa/rsa_test.go b/src/pkg/crypto/rsa/rsa_test.go index d8a936eb6..c36bca1cd 100644 --- a/src/pkg/crypto/rsa/rsa_test.go +++ b/src/pkg/crypto/rsa/rsa_test.go @@ -30,7 +30,20 @@ func Test3PrimeKeyGeneration(t *testing.T) { } size := 768 - priv, err := Generate3PrimeKey(rand.Reader, size) + priv, err := GenerateMultiPrimeKey(rand.Reader, 3, size) + if err != nil { + t.Errorf("failed to generate key") + } + testKeyBasics(t, priv) +} + +func Test4PrimeKeyGeneration(t *testing.T) { + if testing.Short() { + return + } + + size := 768 + priv, err := GenerateMultiPrimeKey(rand.Reader, 4, size) if err != nil { t.Errorf("failed to generate key") } @@ -45,6 +58,7 @@ func testKeyBasics(t *testing.T, priv *PrivateKey) { pub := &priv.PublicKey m := big.NewInt(42) c := encrypt(new(big.Int), pub, m) + m2, err := decrypt(nil, priv, c) if err != nil { t.Errorf("error while decrypting: %s", err) @@ -59,7 +73,7 @@ func testKeyBasics(t *testing.T, priv *PrivateKey) { t.Errorf("error while decrypting (blind): %s", err) } if m.Cmp(m3) != 0 { - t.Errorf("(blind) got:%v, want:%v", m3, m) + t.Errorf("(blind) got:%v, want:%v (%#v)", m3, m, priv) } } @@ -77,10 +91,12 @@ func BenchmarkRSA2048Decrypt(b *testing.B) { E: 3, }, D: fromBase10("9542755287494004433998723259516013739278699355114572217325597900889416163458809501304132487555642811888150937392013824621448709836142886006653296025093941418628992648429798282127303704957273845127141852309016655778568546006839666463451542076964744073572349705538631742281931858219480985907271975884773482372966847639853897890615456605598071088189838676728836833012254065983259638538107719766738032720239892094196108713378822882383694456030043492571063441943847195939549773271694647657549658603365629458610273821292232646334717612674519997533901052790334279661754176490593041941863932308687197618671528035670452762731"), - P: fromBase10("130903255182996722426771613606077755295583329135067340152947172868415809027537376306193179624298874215608270802054347609836776473930072411958753044562214537013874103802006369634761074377213995983876788718033850153719421695468704276694983032644416930879093914927146648402139231293035971427838068945045019075433"), - Q: fromBase10("109348945610485453577574767652527472924289229538286649661240938988020367005475727988253438647560958573506159449538793540472829815903949343191091817779240101054552748665267574271163617694640513549693841337820602726596756351006149518830932261246698766355347898158548465400674856021497190430791824869615170301029"), + Primes: []*big.Int{ + fromBase10("130903255182996722426771613606077755295583329135067340152947172868415809027537376306193179624298874215608270802054347609836776473930072411958753044562214537013874103802006369634761074377213995983876788718033850153719421695468704276694983032644416930879093914927146648402139231293035971427838068945045019075433"), + fromBase10("109348945610485453577574767652527472924289229538286649661240938988020367005475727988253438647560958573506159449538793540472829815903949343191091817779240101054552748665267574271163617694640513549693841337820602726596756351006149518830932261246698766355347898158548465400674856021497190430791824869615170301029"), + }, } - priv.precompute() + priv.Precompute() c := fromBase10("1000") @@ -99,11 +115,13 @@ func Benchmark3PrimeRSA2048Decrypt(b *testing.B) { E: 3, }, D: fromBase10("10897585948254795600358846499957366070880176878341177571733155050184921896034527397712889205732614568234385175145686545381899460748279607074689061600935843283397424506622998458510302603922766336783617368686090042765718290914099334449154829375179958369993407724946186243249568928237086215759259909861748642124071874879861299389874230489928271621259294894142840428407196932444474088857746123104978617098858619445675532587787023228852383149557470077802718705420275739737958953794088728369933811184572620857678792001136676902250566845618813972833750098806496641114644760255910789397593428910198080271317419213080834885003"), - P: fromBase10("1025363189502892836833747188838978207017355117492483312747347695538428729137306368764177201532277413433182799108299960196606011786562992097313508180436744488171474690412562218914213688661311117337381958560443"), - Q: fromBase10("3467903426626310123395340254094941045497208049900750380025518552334536945536837294961497712862519984786362199788654739924501424784631315081391467293694361474867825728031147665777546570788493758372218019373"), - R: fromBase10("4597024781409332673052708605078359346966325141767460991205742124888960305710298765592730135879076084498363772408626791576005136245060321874472727132746643162385746062759369754202494417496879741537284589047"), + Primes: []*big.Int{ + fromBase10("1025363189502892836833747188838978207017355117492483312747347695538428729137306368764177201532277413433182799108299960196606011786562992097313508180436744488171474690412562218914213688661311117337381958560443"), + fromBase10("3467903426626310123395340254094941045497208049900750380025518552334536945536837294961497712862519984786362199788654739924501424784631315081391467293694361474867825728031147665777546570788493758372218019373"), + fromBase10("4597024781409332673052708605078359346966325141767460991205742124888960305710298765592730135879076084498363772408626791576005136245060321874472727132746643162385746062759369754202494417496879741537284589047"), + }, } - priv.precompute() + priv.Precompute() c := fromBase10("1000") diff --git a/src/pkg/crypto/sha1/sha1.go b/src/pkg/crypto/sha1/sha1.go index e6aa096e2..788d1ff55 100644 --- a/src/pkg/crypto/sha1/sha1.go +++ b/src/pkg/crypto/sha1/sha1.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements the SHA1 hash algorithm as defined in RFC 3174. +// Package sha1 implements the SHA1 hash algorithm as defined in RFC 3174. package sha1 import ( diff --git a/src/pkg/crypto/sha256/sha256.go b/src/pkg/crypto/sha256/sha256.go index 69b356b4e..a2c058d18 100644 --- a/src/pkg/crypto/sha256/sha256.go +++ b/src/pkg/crypto/sha256/sha256.go @@ -2,7 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements the SHA224 and SHA256 hash algorithms as defined in FIPS 180-2. +// Package sha256 implements the SHA224 and SHA256 hash algorithms as defined +// in FIPS 180-2. package sha256 import ( diff --git a/src/pkg/crypto/sha512/sha512.go b/src/pkg/crypto/sha512/sha512.go index 7e9f330e5..78f5fe26f 100644 --- a/src/pkg/crypto/sha512/sha512.go +++ b/src/pkg/crypto/sha512/sha512.go @@ -2,7 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements the SHA384 and SHA512 hash algorithms as defined in FIPS 180-2. +// Package sha512 implements the SHA384 and SHA512 hash algorithms as defined +// in FIPS 180-2. package sha512 import ( diff --git a/src/pkg/crypto/subtle/constant_time.go b/src/pkg/crypto/subtle/constant_time.go index a3d70b9c9..57dbe9db5 100644 --- a/src/pkg/crypto/subtle/constant_time.go +++ b/src/pkg/crypto/subtle/constant_time.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements functions that are often useful in cryptographic +// Package subtle implements functions that are often useful in cryptographic // code but require careful thought to use correctly. package subtle diff --git a/src/pkg/crypto/tls/Makefile b/src/pkg/crypto/tls/Makefile index f8ec1511a..000314be5 100644 --- a/src/pkg/crypto/tls/Makefile +++ b/src/pkg/crypto/tls/Makefile @@ -7,7 +7,6 @@ include ../../../Make.inc TARG=crypto/tls GOFILES=\ alert.go\ - ca_set.go\ cipher_suites.go\ common.go\ conn.go\ diff --git a/src/pkg/crypto/tls/ca_set.go b/src/pkg/crypto/tls/ca_set.go deleted file mode 100644 index ae00ac558..000000000 --- a/src/pkg/crypto/tls/ca_set.go +++ /dev/null @@ -1,89 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package tls - -import ( - "crypto/x509" - "encoding/pem" - "strings" -) - -// A CASet is a set of certificates. -type CASet struct { - bySubjectKeyId map[string][]*x509.Certificate - byName map[string][]*x509.Certificate -} - -// NewCASet returns a new, empty CASet. -func NewCASet() *CASet { - return &CASet{ - make(map[string][]*x509.Certificate), - make(map[string][]*x509.Certificate), - } -} - -func nameToKey(name *x509.Name) string { - return strings.Join(name.Country, ",") + "/" + strings.Join(name.Organization, ",") + "/" + strings.Join(name.OrganizationalUnit, ",") + "/" + name.CommonName -} - -// FindVerifiedParent attempts to find the certificate in s which has signed -// the given certificate. If no such certificate can be found or the signature -// doesn't match, it returns nil. -func (s *CASet) FindVerifiedParent(cert *x509.Certificate) (parent *x509.Certificate) { - var candidates []*x509.Certificate - - if len(cert.AuthorityKeyId) > 0 { - candidates = s.bySubjectKeyId[string(cert.AuthorityKeyId)] - } - if len(candidates) == 0 { - candidates = s.byName[nameToKey(&cert.Issuer)] - } - - for _, c := range candidates { - if cert.CheckSignatureFrom(c) == nil { - return c - } - } - - return nil -} - -// AddCert adds a certificate to the set -func (s *CASet) AddCert(cert *x509.Certificate) { - if len(cert.SubjectKeyId) > 0 { - keyId := string(cert.SubjectKeyId) - s.bySubjectKeyId[keyId] = append(s.bySubjectKeyId[keyId], cert) - } - name := nameToKey(&cert.Subject) - s.byName[name] = append(s.byName[name], cert) -} - -// SetFromPEM attempts to parse a series of PEM encoded root certificates. It -// appends any certificates found to s and returns true if any certificates -// were successfully parsed. On many Linux systems, /etc/ssl/cert.pem will -// contains the system wide set of root CAs in a format suitable for this -// function. -func (s *CASet) SetFromPEM(pemCerts []byte) (ok bool) { - for len(pemCerts) > 0 { - var block *pem.Block - block, pemCerts = pem.Decode(pemCerts) - if block == nil { - break - } - if block.Type != "CERTIFICATE" || len(block.Headers) != 0 { - continue - } - - cert, err := x509.ParseCertificate(block.Bytes) - if err != nil { - continue - } - - s.AddCert(cert) - ok = true - } - - return -} diff --git a/src/pkg/crypto/tls/common.go b/src/pkg/crypto/tls/common.go index fb2916ae0..204d25531 100644 --- a/src/pkg/crypto/tls/common.go +++ b/src/pkg/crypto/tls/common.go @@ -122,7 +122,7 @@ type Config struct { // RootCAs defines the set of root certificate authorities // that clients use when verifying server certificates. // If RootCAs is nil, TLS uses the host's root CA set. - RootCAs *CASet + RootCAs *x509.CertPool // NextProtos is a list of supported, application level protocols. NextProtos []string @@ -158,7 +158,7 @@ func (c *Config) time() int64 { return t() } -func (c *Config) rootCAs() *CASet { +func (c *Config) rootCAs() *x509.CertPool { s := c.RootCAs if s == nil { s = defaultRoots() @@ -178,6 +178,9 @@ func (c *Config) cipherSuites() []uint16 { type Certificate struct { Certificate [][]byte PrivateKey *rsa.PrivateKey + // OCSPStaple contains an optional OCSP response which will be served + // to clients that request it. + OCSPStaple []byte } // A TLS record. @@ -221,7 +224,7 @@ var certFiles = []string{ var once sync.Once -func defaultRoots() *CASet { +func defaultRoots() *x509.CertPool { once.Do(initDefaults) return varDefaultRoots } @@ -236,14 +239,14 @@ func initDefaults() { initDefaultCipherSuites() } -var varDefaultRoots *CASet +var varDefaultRoots *x509.CertPool func initDefaultRoots() { - roots := NewCASet() + roots := x509.NewCertPool() for _, file := range certFiles { data, err := ioutil.ReadFile(file) if err == nil { - roots.SetFromPEM(data) + roots.AppendCertsFromPEM(data) break } } diff --git a/src/pkg/crypto/tls/conn.go b/src/pkg/crypto/tls/conn.go index b94e235c8..63d56310c 100644 --- a/src/pkg/crypto/tls/conn.go +++ b/src/pkg/crypto/tls/conn.go @@ -34,6 +34,9 @@ type Conn struct { cipherSuite uint16 ocspResponse []byte // stapled OCSP response peerCertificates []*x509.Certificate + // verifedChains contains the certificate chains that we built, as + // opposed to the ones presented by the server. + verifiedChains [][]*x509.Certificate clientProtocol string clientProtocolFallback bool diff --git a/src/pkg/crypto/tls/handshake_client.go b/src/pkg/crypto/tls/handshake_client.go index 540b25c87..c758c96d4 100644 --- a/src/pkg/crypto/tls/handshake_client.go +++ b/src/pkg/crypto/tls/handshake_client.go @@ -88,7 +88,6 @@ func (c *Conn) clientHandshake() os.Error { finishedHash.Write(certMsg.marshal()) certs := make([]*x509.Certificate, len(certMsg.certificates)) - chain := NewCASet() for i, asn1Data := range certMsg.certificates { cert, err := x509.ParseCertificate(asn1Data) if err != nil { @@ -96,47 +95,29 @@ func (c *Conn) clientHandshake() os.Error { return os.ErrorString("failed to parse certificate from server: " + err.String()) } certs[i] = cert - chain.AddCert(cert) } // If we don't have a root CA set configured then anything is accepted. // TODO(rsc): Find certificates for OS X 10.6. - for cur := certs[0]; c.config.RootCAs != nil; { - parent := c.config.RootCAs.FindVerifiedParent(cur) - if parent != nil { - break + if c.config.RootCAs != nil { + opts := x509.VerifyOptions{ + Roots: c.config.RootCAs, + CurrentTime: c.config.time(), + DNSName: c.config.ServerName, + Intermediates: x509.NewCertPool(), } - parent = chain.FindVerifiedParent(cur) - if parent == nil { - c.sendAlert(alertBadCertificate) - return os.ErrorString("could not find root certificate for chain") + for i, cert := range certs { + if i == 0 { + continue + } + opts.Intermediates.AddCert(cert) } - - if !parent.BasicConstraintsValid || !parent.IsCA { + c.verifiedChains, err = certs[0].Verify(opts) + if err != nil { c.sendAlert(alertBadCertificate) - return os.ErrorString("intermediate certificate does not have CA bit set") + return err } - // KeyUsage status flags are ignored. From Engineering - // Security, Peter Gutmann: A European government CA marked its - // signing certificates as being valid for encryption only, but - // no-one noticed. Another European CA marked its signature - // keys as not being valid for signatures. A different CA - // marked its own trusted root certificate as being invalid for - // certificate signing. Another national CA distributed a - // certificate to be used to encrypt data for the country’s tax - // authority that was marked as only being usable for digital - // signatures but not for encryption. Yet another CA reversed - // the order of the bit flags in the keyUsage due to confusion - // over encoding endianness, essentially setting a random - // keyUsage in certificates that it issued. Another CA created - // a self-invalidating certificate by adding a certificate - // policy statement stipulating that the certificate had to be - // used strictly as specified in the keyUsage, and a keyUsage - // containing a flag indicating that the RSA encryption key - // could only be used for Diffie-Hellman key agreement. - - cur = parent } if _, ok := certs[0].PublicKey.(*rsa.PublicKey); !ok { @@ -145,7 +126,7 @@ func (c *Conn) clientHandshake() os.Error { c.peerCertificates = certs - if serverHello.certStatus { + if serverHello.ocspStapling { msg, err = c.readHandshake() if err != nil { return err diff --git a/src/pkg/crypto/tls/handshake_messages.go b/src/pkg/crypto/tls/handshake_messages.go index e5e856271..6645adce4 100644 --- a/src/pkg/crypto/tls/handshake_messages.go +++ b/src/pkg/crypto/tls/handshake_messages.go @@ -306,7 +306,7 @@ type serverHelloMsg struct { compressionMethod uint8 nextProtoNeg bool nextProtos []string - certStatus bool + ocspStapling bool } func (m *serverHelloMsg) marshal() []byte { @@ -327,7 +327,7 @@ func (m *serverHelloMsg) marshal() []byte { nextProtoLen += len(m.nextProtos) extensionsLength += nextProtoLen } - if m.certStatus { + if m.ocspStapling { numExtensions++ } if numExtensions > 0 { @@ -373,7 +373,7 @@ func (m *serverHelloMsg) marshal() []byte { z = z[1+l:] } } - if m.certStatus { + if m.ocspStapling { z[0] = byte(extensionStatusRequest >> 8) z[1] = byte(extensionStatusRequest) z = z[4:] @@ -406,7 +406,7 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool { m.nextProtoNeg = false m.nextProtos = nil - m.certStatus = false + m.ocspStapling = false if len(data) == 0 { // ServerHello is optionally followed by extension data @@ -450,7 +450,7 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool { if length > 0 { return false } - m.certStatus = true + m.ocspStapling = true } data = data[length:] } diff --git a/src/pkg/crypto/tls/handshake_messages_test.go b/src/pkg/crypto/tls/handshake_messages_test.go index f5e94e269..23f729dd9 100644 --- a/src/pkg/crypto/tls/handshake_messages_test.go +++ b/src/pkg/crypto/tls/handshake_messages_test.go @@ -32,7 +32,7 @@ type testMessage interface { func TestMarshalUnmarshal(t *testing.T) { rand := rand.New(rand.NewSource(0)) for i, iface := range tests { - ty := reflect.NewValue(iface).Type() + ty := reflect.ValueOf(iface).Type() n := 100 if testing.Short() { @@ -125,7 +125,7 @@ func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { m.supportedCurves[i] = uint16(rand.Intn(30000)) } - return reflect.NewValue(m) + return reflect.ValueOf(m) } func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { @@ -146,7 +146,7 @@ func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { } } - return reflect.NewValue(m) + return reflect.ValueOf(m) } func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value { @@ -156,7 +156,7 @@ func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value { for i := 0; i < numCerts; i++ { m.certificates[i] = randomBytes(rand.Intn(10)+1, rand) } - return reflect.NewValue(m) + return reflect.ValueOf(m) } func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value { @@ -167,13 +167,13 @@ func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value for i := 0; i < numCAs; i++ { m.certificateAuthorities[i] = randomBytes(rand.Intn(15)+1, rand) } - return reflect.NewValue(m) + return reflect.ValueOf(m) } func (*certificateVerifyMsg) Generate(rand *rand.Rand, size int) reflect.Value { m := &certificateVerifyMsg{} m.signature = randomBytes(rand.Intn(15)+1, rand) - return reflect.NewValue(m) + return reflect.ValueOf(m) } func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value { @@ -184,23 +184,23 @@ func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value { } else { m.statusType = 42 } - return reflect.NewValue(m) + return reflect.ValueOf(m) } func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value { m := &clientKeyExchangeMsg{} m.ciphertext = randomBytes(rand.Intn(1000)+1, rand) - return reflect.NewValue(m) + return reflect.ValueOf(m) } func (*finishedMsg) Generate(rand *rand.Rand, size int) reflect.Value { m := &finishedMsg{} m.verifyData = randomBytes(12, rand) - return reflect.NewValue(m) + return reflect.ValueOf(m) } func (*nextProtoMsg) Generate(rand *rand.Rand, size int) reflect.Value { m := &nextProtoMsg{} m.proto = randomString(rand.Intn(255), rand) - return reflect.NewValue(m) + return reflect.ValueOf(m) } diff --git a/src/pkg/crypto/tls/handshake_server.go b/src/pkg/crypto/tls/handshake_server.go index 809c8c15e..37c8d154a 100644 --- a/src/pkg/crypto/tls/handshake_server.go +++ b/src/pkg/crypto/tls/handshake_server.go @@ -103,6 +103,9 @@ FindCipherSuite: hello.nextProtoNeg = true hello.nextProtos = config.NextProtos } + if clientHello.ocspStapling && len(config.Certificates[0].OCSPStaple) > 0 { + hello.ocspStapling = true + } finishedHash.Write(hello.marshal()) c.writeRecord(recordTypeHandshake, hello.marshal()) @@ -116,6 +119,14 @@ FindCipherSuite: finishedHash.Write(certMsg.marshal()) c.writeRecord(recordTypeHandshake, certMsg.marshal()) + if hello.ocspStapling { + certStatus := new(certificateStatusMsg) + certStatus.statusType = statusTypeOCSP + certStatus.response = config.Certificates[0].OCSPStaple + finishedHash.Write(certStatus.marshal()) + c.writeRecord(recordTypeHandshake, certStatus.marshal()) + } + keyAgreement := suite.ka() skx, err := keyAgreement.generateServerKeyExchange(config, clientHello, hello) diff --git a/src/pkg/crypto/tls/handshake_server_test.go b/src/pkg/crypto/tls/handshake_server_test.go index 6beb6a9f6..5a1e754dc 100644 --- a/src/pkg/crypto/tls/handshake_server_test.go +++ b/src/pkg/crypto/tls/handshake_server_test.go @@ -188,8 +188,10 @@ var testPrivateKey = &rsa.PrivateKey{ E: 65537, }, D: bigFromString("29354450337804273969007277378287027274721892607543397931919078829901848876371746653677097639302788129485893852488285045793268732234230875671682624082413996177431586734171663258657462237320300610850244186316880055243099640544518318093544057213190320837094958164973959123058337475052510833916491060913053867729"), - P: bigFromString("11969277782311800166562047708379380720136961987713178380670422671426759650127150688426177829077494755200794297055316163155755835813760102405344560929062149"), - Q: bigFromString("10998999429884441391899182616418192492905073053684657075974935218461686523870125521822756579792315215543092255516093840728890783887287417039645833477273829"), + Primes: []*big.Int{ + bigFromString("11969277782311800166562047708379380720136961987713178380670422671426759650127150688426177829077494755200794297055316163155755835813760102405344560929062149"), + bigFromString("10998999429884441391899182616418192492905073053684657075974935218461686523870125521822756579792315215543092255516093840728890783887287417039645833477273829"), + }, } // Script of interaction with gnutls implementation. diff --git a/src/pkg/crypto/tls/tls.go b/src/pkg/crypto/tls/tls.go index 7de44bbd2..7d0bb9f34 100644 --- a/src/pkg/crypto/tls/tls.go +++ b/src/pkg/crypto/tls/tls.go @@ -2,7 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package partially implements the TLS 1.1 protocol, as specified in RFC 4346. +// Package tls partially implements the TLS 1.1 protocol, as specified in RFC +// 4346. package tls import ( diff --git a/src/pkg/crypto/twofish/twofish.go b/src/pkg/crypto/twofish/twofish.go index 62253e797..9303f03ff 100644 --- a/src/pkg/crypto/twofish/twofish.go +++ b/src/pkg/crypto/twofish/twofish.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements Bruce Schneier's Twofish encryption algorithm. +// Package twofish implements Bruce Schneier's Twofish encryption algorithm. package twofish // Twofish is defined in http://www.schneier.com/paper-twofish-paper.pdf [TWOFISH] diff --git a/src/pkg/crypto/x509/Makefile b/src/pkg/crypto/x509/Makefile index 329a61b7c..14ffd095f 100644 --- a/src/pkg/crypto/x509/Makefile +++ b/src/pkg/crypto/x509/Makefile @@ -6,6 +6,8 @@ include ../../../Make.inc TARG=crypto/x509 GOFILES=\ + cert_pool.go\ + verify.go\ x509.go\ include ../../../Make.pkg diff --git a/src/pkg/crypto/x509/cert_pool.go b/src/pkg/crypto/x509/cert_pool.go new file mode 100644 index 000000000..c295fd97e --- /dev/null +++ b/src/pkg/crypto/x509/cert_pool.go @@ -0,0 +1,105 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package x509 + +import ( + "encoding/pem" + "strings" +) + +// Roots is a set of certificates. +type CertPool struct { + bySubjectKeyId map[string][]int + byName map[string][]int + certs []*Certificate +} + +// NewCertPool returns a new, empty CertPool. +func NewCertPool() *CertPool { + return &CertPool{ + make(map[string][]int), + make(map[string][]int), + nil, + } +} + +func nameToKey(name *Name) string { + return strings.Join(name.Country, ",") + "/" + strings.Join(name.Organization, ",") + "/" + strings.Join(name.OrganizationalUnit, ",") + "/" + name.CommonName +} + +// findVerifiedParents attempts to find certificates in s which have signed the +// given certificate. If no such certificate can be found or the signature +// doesn't match, it returns nil. +func (s *CertPool) findVerifiedParents(cert *Certificate) (parents []int) { + var candidates []int + + if len(cert.AuthorityKeyId) > 0 { + candidates = s.bySubjectKeyId[string(cert.AuthorityKeyId)] + } + if len(candidates) == 0 { + candidates = s.byName[nameToKey(&cert.Issuer)] + } + + for _, c := range candidates { + if cert.CheckSignatureFrom(s.certs[c]) == nil { + parents = append(parents, c) + } + } + + return +} + +// AddCert adds a certificate to a pool. +func (s *CertPool) AddCert(cert *Certificate) { + if cert == nil { + panic("adding nil Certificate to CertPool") + } + + // Check that the certificate isn't being added twice. + for _, c := range s.certs { + if c.Equal(cert) { + return + } + } + + n := len(s.certs) + s.certs = append(s.certs, cert) + + if len(cert.SubjectKeyId) > 0 { + keyId := string(cert.SubjectKeyId) + s.bySubjectKeyId[keyId] = append(s.bySubjectKeyId[keyId], n) + } + name := nameToKey(&cert.Subject) + s.byName[name] = append(s.byName[name], n) +} + +// AppendCertsFromPEM attempts to parse a series of PEM encoded root +// certificates. It appends any certificates found to s and returns true if any +// certificates were successfully parsed. +// +// On many Linux systems, /etc/ssl/cert.pem will contains the system wide set +// of root CAs in a format suitable for this function. +func (s *CertPool) AppendCertsFromPEM(pemCerts []byte) (ok bool) { + for len(pemCerts) > 0 { + var block *pem.Block + block, pemCerts = pem.Decode(pemCerts) + if block == nil { + break + } + if block.Type != "CERTIFICATE" || len(block.Headers) != 0 { + continue + } + + cert, err := ParseCertificate(block.Bytes) + if err != nil { + continue + } + + s.AddCert(cert) + ok = true + } + + return +} diff --git a/src/pkg/crypto/x509/verify.go b/src/pkg/crypto/x509/verify.go new file mode 100644 index 000000000..9145880a2 --- /dev/null +++ b/src/pkg/crypto/x509/verify.go @@ -0,0 +1,239 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package x509 + +import ( + "os" + "strings" + "time" +) + +type InvalidReason int + +const ( + // NotAuthorizedToSign results when a certificate is signed by another + // which isn't marked as a CA certificate. + NotAuthorizedToSign InvalidReason = iota + // Expired results when a certificate has expired, based on the time + // given in the VerifyOptions. + Expired + // CANotAuthorizedForThisName results when an intermediate or root + // certificate has a name constraint which doesn't include the name + // being checked. + CANotAuthorizedForThisName +) + +// CertificateInvalidError results when an odd error occurs. Users of this +// library probably want to handle all these errors uniformly. +type CertificateInvalidError struct { + Cert *Certificate + Reason InvalidReason +} + +func (e CertificateInvalidError) String() string { + switch e.Reason { + case NotAuthorizedToSign: + return "x509: certificate is not authorized to sign other other certificates" + case Expired: + return "x509: certificate has expired or is not yet valid" + case CANotAuthorizedForThisName: + return "x509: a root or intermediate certificate is not authorized to sign in this domain" + } + return "x509: unknown error" +} + +// HostnameError results when the set of authorized names doesn't match the +// requested name. +type HostnameError struct { + Certificate *Certificate + Host string +} + +func (h HostnameError) String() string { + var valid string + c := h.Certificate + if len(c.DNSNames) > 0 { + valid = strings.Join(c.DNSNames, ", ") + } else { + valid = c.Subject.CommonName + } + return "certificate is valid for " + valid + ", not " + h.Host +} + + +// UnknownAuthorityError results when the certificate issuer is unknown +type UnknownAuthorityError struct { + cert *Certificate +} + +func (e UnknownAuthorityError) String() string { + return "x509: certificate signed by unknown authority" +} + +// VerifyOptions contains parameters for Certificate.Verify. It's a structure +// because other PKIX verification APIs have ended up needing many options. +type VerifyOptions struct { + DNSName string + Intermediates *CertPool + Roots *CertPool + CurrentTime int64 // if 0, the current system time is used. +} + +const ( + leafCertificate = iota + intermediateCertificate + rootCertificate +) + +// isValid performs validity checks on the c. +func (c *Certificate) isValid(certType int, opts *VerifyOptions) os.Error { + if opts.CurrentTime < c.NotBefore.Seconds() || + opts.CurrentTime > c.NotAfter.Seconds() { + return CertificateInvalidError{c, Expired} + } + + if len(c.PermittedDNSDomains) > 0 { + for _, domain := range c.PermittedDNSDomains { + if opts.DNSName == domain || + (strings.HasSuffix(opts.DNSName, domain) && + len(opts.DNSName) >= 1+len(domain) && + opts.DNSName[len(opts.DNSName)-len(domain)-1] == '.') { + continue + } + + return CertificateInvalidError{c, CANotAuthorizedForThisName} + } + } + + // KeyUsage status flags are ignored. From Engineering Security, Peter + // Gutmann: A European government CA marked its signing certificates as + // being valid for encryption only, but no-one noticed. Another + // European CA marked its signature keys as not being valid for + // signatures. A different CA marked its own trusted root certificate + // as being invalid for certificate signing. Another national CA + // distributed a certificate to be used to encrypt data for the + // country’s tax authority that was marked as only being usable for + // digital signatures but not for encryption. Yet another CA reversed + // the order of the bit flags in the keyUsage due to confusion over + // encoding endianness, essentially setting a random keyUsage in + // certificates that it issued. Another CA created a self-invalidating + // certificate by adding a certificate policy statement stipulating + // that the certificate had to be used strictly as specified in the + // keyUsage, and a keyUsage containing a flag indicating that the RSA + // encryption key could only be used for Diffie-Hellman key agreement. + + if certType == intermediateCertificate && (!c.BasicConstraintsValid || !c.IsCA) { + return CertificateInvalidError{c, NotAuthorizedToSign} + } + + return nil +} + +// Verify attempts to verify c by building one or more chains from c to a +// certificate in opts.roots, using certificates in opts.Intermediates if +// needed. If successful, it returns one or chains where the first element of +// the chain is c and the last element is from opts.Roots. +// +// WARNING: this doesn't do any revocation checking. +func (c *Certificate) Verify(opts VerifyOptions) (chains [][]*Certificate, err os.Error) { + if opts.CurrentTime == 0 { + opts.CurrentTime = time.Seconds() + } + err = c.isValid(leafCertificate, &opts) + if err != nil { + return + } + if len(opts.DNSName) > 0 { + err = c.VerifyHostname(opts.DNSName) + if err != nil { + return + } + } + return c.buildChains(make(map[int][][]*Certificate), []*Certificate{c}, &opts) +} + +func appendToFreshChain(chain []*Certificate, cert *Certificate) []*Certificate { + n := make([]*Certificate, len(chain)+1) + copy(n, chain) + n[len(chain)] = cert + return n +} + +func (c *Certificate) buildChains(cache map[int][][]*Certificate, currentChain []*Certificate, opts *VerifyOptions) (chains [][]*Certificate, err os.Error) { + for _, rootNum := range opts.Roots.findVerifiedParents(c) { + root := opts.Roots.certs[rootNum] + err = root.isValid(rootCertificate, opts) + if err != nil { + continue + } + chains = append(chains, appendToFreshChain(currentChain, root)) + } + + for _, intermediateNum := range opts.Intermediates.findVerifiedParents(c) { + intermediate := opts.Intermediates.certs[intermediateNum] + err = intermediate.isValid(intermediateCertificate, opts) + if err != nil { + continue + } + var childChains [][]*Certificate + childChains, ok := cache[intermediateNum] + if !ok { + childChains, err = intermediate.buildChains(cache, appendToFreshChain(currentChain, intermediate), opts) + cache[intermediateNum] = childChains + } + chains = append(chains, childChains...) + } + + if len(chains) > 0 { + err = nil + } + + if len(chains) == 0 && err == nil { + err = UnknownAuthorityError{c} + } + + return +} + +func matchHostnames(pattern, host string) bool { + if len(pattern) == 0 || len(host) == 0 { + return false + } + + patternParts := strings.Split(pattern, ".", -1) + hostParts := strings.Split(host, ".", -1) + + if len(patternParts) != len(hostParts) { + return false + } + + for i, patternPart := range patternParts { + if patternPart == "*" { + continue + } + if patternPart != hostParts[i] { + return false + } + } + + return true +} + +// VerifyHostname returns nil if c is a valid certificate for the named host. +// Otherwise it returns an os.Error describing the mismatch. +func (c *Certificate) VerifyHostname(h string) os.Error { + if len(c.DNSNames) > 0 { + for _, match := range c.DNSNames { + if matchHostnames(match, h) { + return nil + } + } + // If Subject Alt Name is given, we ignore the common name. + } else if matchHostnames(c.Subject.CommonName, h) { + return nil + } + + return HostnameError{c, h} +} diff --git a/src/pkg/crypto/x509/verify_test.go b/src/pkg/crypto/x509/verify_test.go new file mode 100644 index 000000000..6a103dcfb --- /dev/null +++ b/src/pkg/crypto/x509/verify_test.go @@ -0,0 +1,390 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package x509 + +import ( + "encoding/pem" + "os" + "strings" + "testing" +) + +type verifyTest struct { + leaf string + intermediates []string + roots []string + currentTime int64 + dnsName string + + errorCallback func(*testing.T, int, os.Error) bool + expectedChains [][]string +} + +var verifyTests = []verifyTest{ + { + leaf: googleLeaf, + intermediates: []string{thawteIntermediate}, + roots: []string{verisignRoot}, + currentTime: 1302726541, + dnsName: "www.google.com", + + expectedChains: [][]string{ + []string{"Google", "Thawte", "VeriSign"}, + }, + }, + { + leaf: googleLeaf, + intermediates: []string{thawteIntermediate}, + roots: []string{verisignRoot}, + currentTime: 1302726541, + dnsName: "www.example.com", + + errorCallback: expectHostnameError, + }, + { + leaf: googleLeaf, + intermediates: []string{thawteIntermediate}, + roots: []string{verisignRoot}, + currentTime: 1, + dnsName: "www.example.com", + + errorCallback: expectExpired, + }, + { + leaf: googleLeaf, + roots: []string{verisignRoot}, + currentTime: 1302726541, + dnsName: "www.google.com", + + errorCallback: expectAuthorityUnknown, + }, + { + leaf: googleLeaf, + intermediates: []string{verisignRoot, thawteIntermediate}, + roots: []string{verisignRoot}, + currentTime: 1302726541, + dnsName: "www.google.com", + + expectedChains: [][]string{ + []string{"Google", "Thawte", "VeriSign"}, + }, + }, + { + leaf: googleLeaf, + intermediates: []string{verisignRoot, thawteIntermediate}, + roots: []string{verisignRoot}, + currentTime: 1302726541, + + expectedChains: [][]string{ + []string{"Google", "Thawte", "VeriSign"}, + }, + }, + { + leaf: dnssecExpLeaf, + intermediates: []string{startComIntermediate}, + roots: []string{startComRoot}, + currentTime: 1302726541, + + expectedChains: [][]string{ + []string{"dnssec-exp", "StartCom Class 1", "StartCom Certification Authority"}, + }, + }, +} + +func expectHostnameError(t *testing.T, i int, err os.Error) (ok bool) { + if _, ok := err.(HostnameError); !ok { + t.Errorf("#%d: error was not a HostnameError: %s", i, err) + return false + } + return true +} + +func expectExpired(t *testing.T, i int, err os.Error) (ok bool) { + if inval, ok := err.(CertificateInvalidError); !ok || inval.Reason != Expired { + t.Errorf("#%d: error was not Expired: %s", i, err) + return false + } + return true +} + +func expectAuthorityUnknown(t *testing.T, i int, err os.Error) (ok bool) { + if _, ok := err.(UnknownAuthorityError); !ok { + t.Errorf("#%d: error was not UnknownAuthorityError: %s", i, err) + return false + } + return true +} + +func certificateFromPEM(pemBytes string) (*Certificate, os.Error) { + block, _ := pem.Decode([]byte(pemBytes)) + if block == nil { + return nil, os.ErrorString("failed to decode PEM") + } + return ParseCertificate(block.Bytes) +} + +func TestVerify(t *testing.T) { + for i, test := range verifyTests { + opts := VerifyOptions{ + Roots: NewCertPool(), + Intermediates: NewCertPool(), + DNSName: test.dnsName, + CurrentTime: test.currentTime, + } + + for j, root := range test.roots { + ok := opts.Roots.AppendCertsFromPEM([]byte(root)) + if !ok { + t.Errorf("#%d: failed to parse root #%d", i, j) + return + } + } + + for j, intermediate := range test.intermediates { + ok := opts.Intermediates.AppendCertsFromPEM([]byte(intermediate)) + if !ok { + t.Errorf("#%d: failed to parse intermediate #%d", i, j) + return + } + } + + leaf, err := certificateFromPEM(test.leaf) + if err != nil { + t.Errorf("#%d: failed to parse leaf: %s", i, err) + return + } + + chains, err := leaf.Verify(opts) + + if test.errorCallback == nil && err != nil { + t.Errorf("#%d: unexpected error: %s", i, err) + } + if test.errorCallback != nil { + if !test.errorCallback(t, i, err) { + return + } + } + + if len(chains) != len(test.expectedChains) { + t.Errorf("#%d: wanted %d chains, got %d", i, len(test.expectedChains), len(chains)) + } + + // We check that each returned chain matches a chain from + // expectedChains but an entry in expectedChains can't match + // two chains. + seenChains := make([]bool, len(chains)) + NextOutputChain: + for _, chain := range chains { + TryNextExpected: + for j, expectedChain := range test.expectedChains { + if seenChains[j] { + continue + } + if len(chain) != len(expectedChain) { + continue + } + for k, cert := range chain { + if strings.Index(nameToKey(&cert.Subject), expectedChain[k]) == -1 { + continue TryNextExpected + } + } + // we matched + seenChains[j] = true + continue NextOutputChain + } + t.Errorf("#%d: No expected chain matched %s", i, chainToDebugString(chain)) + } + } +} + +func chainToDebugString(chain []*Certificate) string { + var chainStr string + for _, cert := range chain { + if len(chainStr) > 0 { + chainStr += " -> " + } + chainStr += nameToKey(&cert.Subject) + } + return chainStr +} + +const verisignRoot = `-----BEGIN CERTIFICATE----- +MIICPDCCAaUCEHC65B0Q2Sk0tjjKewPMur8wDQYJKoZIhvcNAQECBQAwXzELMAkG +A1UEBhMCVVMxFzAVBgNVBAoTDlZlcmlTaWduLCBJbmMuMTcwNQYDVQQLEy5DbGFz +cyAzIFB1YmxpYyBQcmltYXJ5IENlcnRpZmljYXRpb24gQXV0aG9yaXR5MB4XDTk2 +MDEyOTAwMDAwMFoXDTI4MDgwMTIzNTk1OVowXzELMAkGA1UEBhMCVVMxFzAVBgNV +BAoTDlZlcmlTaWduLCBJbmMuMTcwNQYDVQQLEy5DbGFzcyAzIFB1YmxpYyBQcmlt +YXJ5IENlcnRpZmljYXRpb24gQXV0aG9yaXR5MIGfMA0GCSqGSIb3DQEBAQUAA4GN +ADCBiQKBgQDJXFme8huKARS0EN8EQNvjV69qRUCPhAwL0TPZ2RHP7gJYHyX3KqhE +BarsAx94f56TuZoAqiN91qyFomNFx3InzPRMxnVx0jnvT0Lwdd8KkMaOIG+YD/is +I19wKTakyYbnsZogy1Olhec9vn2a/iRFM9x2Fe0PonFkTGUugWhFpwIDAQABMA0G +CSqGSIb3DQEBAgUAA4GBALtMEivPLCYATxQT3ab7/AoRhIzzKBxnki98tsX63/Do +lbwdj2wsqFHMc9ikwFPwTtYmwHYBV4GSXiHx0bH/59AhWM1pF+NEHJwZRDmJXNyc +AA9WjQKZ7aKQRUzkuxCkPfAyAw7xzvjoyVGM5mKf5p/AfbdynMk2OmufTqj/ZA1k +-----END CERTIFICATE----- +` + +const thawteIntermediate = `-----BEGIN CERTIFICATE----- +MIIDIzCCAoygAwIBAgIEMAAAAjANBgkqhkiG9w0BAQUFADBfMQswCQYDVQQGEwJV +UzEXMBUGA1UEChMOVmVyaVNpZ24sIEluYy4xNzA1BgNVBAsTLkNsYXNzIDMgUHVi +bGljIFByaW1hcnkgQ2VydGlmaWNhdGlvbiBBdXRob3JpdHkwHhcNMDQwNTEzMDAw +MDAwWhcNMTQwNTEyMjM1OTU5WjBMMQswCQYDVQQGEwJaQTElMCMGA1UEChMcVGhh +d3RlIENvbnN1bHRpbmcgKFB0eSkgTHRkLjEWMBQGA1UEAxMNVGhhd3RlIFNHQyBD +QTCBnzANBgkqhkiG9w0BAQEFAAOBjQAwgYkCgYEA1NNn0I0Vf67NMf59HZGhPwtx +PKzMyGT7Y/wySweUvW+Aui/hBJPAM/wJMyPpC3QrccQDxtLN4i/1CWPN/0ilAL/g +5/OIty0y3pg25gqtAHvEZEo7hHUD8nCSfQ5i9SGraTaEMXWQ+L/HbIgbBpV8yeWo +3nWhLHpo39XKHIdYYBkCAwEAAaOB/jCB+zASBgNVHRMBAf8ECDAGAQH/AgEAMAsG +A1UdDwQEAwIBBjARBglghkgBhvhCAQEEBAMCAQYwKAYDVR0RBCEwH6QdMBsxGTAX +BgNVBAMTEFByaXZhdGVMYWJlbDMtMTUwMQYDVR0fBCowKDAmoCSgIoYgaHR0cDov +L2NybC52ZXJpc2lnbi5jb20vcGNhMy5jcmwwMgYIKwYBBQUHAQEEJjAkMCIGCCsG +AQUFBzABhhZodHRwOi8vb2NzcC50aGF3dGUuY29tMDQGA1UdJQQtMCsGCCsGAQUF +BwMBBggrBgEFBQcDAgYJYIZIAYb4QgQBBgpghkgBhvhFAQgBMA0GCSqGSIb3DQEB +BQUAA4GBAFWsY+reod3SkF+fC852vhNRj5PZBSvIG3dLrWlQoe7e3P3bB+noOZTc +q3J5Lwa/q4FwxKjt6lM07e8eU9kGx1Yr0Vz00YqOtCuxN5BICEIlxT6Ky3/rbwTR +bcV0oveifHtgPHfNDs5IAn8BL7abN+AqKjbc1YXWrOU/VG+WHgWv +-----END CERTIFICATE----- +` + +const googleLeaf = `-----BEGIN CERTIFICATE----- +MIIDITCCAoqgAwIBAgIQL9+89q6RUm0PmqPfQDQ+mjANBgkqhkiG9w0BAQUFADBM +MQswCQYDVQQGEwJaQTElMCMGA1UEChMcVGhhd3RlIENvbnN1bHRpbmcgKFB0eSkg +THRkLjEWMBQGA1UEAxMNVGhhd3RlIFNHQyBDQTAeFw0wOTEyMTgwMDAwMDBaFw0x +MTEyMTgyMzU5NTlaMGgxCzAJBgNVBAYTAlVTMRMwEQYDVQQIEwpDYWxpZm9ybmlh +MRYwFAYDVQQHFA1Nb3VudGFpbiBWaWV3MRMwEQYDVQQKFApHb29nbGUgSW5jMRcw +FQYDVQQDFA53d3cuZ29vZ2xlLmNvbTCBnzANBgkqhkiG9w0BAQEFAAOBjQAwgYkC +gYEA6PmGD5D6htffvXImttdEAoN4c9kCKO+IRTn7EOh8rqk41XXGOOsKFQebg+jN +gtXj9xVoRaELGYW84u+E593y17iYwqG7tcFR39SDAqc9BkJb4SLD3muFXxzW2k6L +05vuuWciKh0R73mkszeK9P4Y/bz5RiNQl/Os/CRGK1w7t0UCAwEAAaOB5zCB5DAM +BgNVHRMBAf8EAjAAMDYGA1UdHwQvMC0wK6ApoCeGJWh0dHA6Ly9jcmwudGhhd3Rl +LmNvbS9UaGF3dGVTR0NDQS5jcmwwKAYDVR0lBCEwHwYIKwYBBQUHAwEGCCsGAQUF +BwMCBglghkgBhvhCBAEwcgYIKwYBBQUHAQEEZjBkMCIGCCsGAQUFBzABhhZodHRw +Oi8vb2NzcC50aGF3dGUuY29tMD4GCCsGAQUFBzAChjJodHRwOi8vd3d3LnRoYXd0 +ZS5jb20vcmVwb3NpdG9yeS9UaGF3dGVfU0dDX0NBLmNydDANBgkqhkiG9w0BAQUF +AAOBgQCfQ89bxFApsb/isJr/aiEdLRLDLE5a+RLizrmCUi3nHX4adpaQedEkUjh5 +u2ONgJd8IyAPkU0Wueru9G2Jysa9zCRo1kNbzipYvzwY4OA8Ys+WAi0oR1A04Se6 +z5nRUP8pJcA2NhUzUnC+MY+f6H/nEQyNv4SgQhqAibAxWEEHXw== +-----END CERTIFICATE-----` + +const dnssecExpLeaf = `-----BEGIN CERTIFICATE----- +MIIGzTCCBbWgAwIBAgIDAdD6MA0GCSqGSIb3DQEBBQUAMIGMMQswCQYDVQQGEwJJ +TDEWMBQGA1UEChMNU3RhcnRDb20gTHRkLjErMCkGA1UECxMiU2VjdXJlIERpZ2l0 +YWwgQ2VydGlmaWNhdGUgU2lnbmluZzE4MDYGA1UEAxMvU3RhcnRDb20gQ2xhc3Mg +MSBQcmltYXJ5IEludGVybWVkaWF0ZSBTZXJ2ZXIgQ0EwHhcNMTAwNzA0MTQ1MjQ1 +WhcNMTEwNzA1MTA1NzA0WjCBwTEgMB4GA1UEDRMXMjIxMTM3LWxpOWE5dHhJRzZM +NnNyVFMxCzAJBgNVBAYTAlVTMR4wHAYDVQQKExVQZXJzb25hIE5vdCBWYWxpZGF0 +ZWQxKTAnBgNVBAsTIFN0YXJ0Q29tIEZyZWUgQ2VydGlmaWNhdGUgTWVtYmVyMRsw +GQYDVQQDExJ3d3cuZG5zc2VjLWV4cC5vcmcxKDAmBgkqhkiG9w0BCQEWGWhvc3Rt +YXN0ZXJAZG5zc2VjLWV4cC5vcmcwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEK +AoIBAQDEdF/22vaxrPbqpgVYMWi+alfpzBctpbfLBdPGuqOazJdCT0NbWcK8/+B4 +X6OlSOURNIlwLzhkmwVsWdVv6dVSaN7d4yI/fJkvgfDB9+au+iBJb6Pcz8ULBfe6 +D8HVvqKdORp6INzHz71z0sghxrQ0EAEkoWAZLh+kcn2ZHdcmZaBNUfjmGbyU6PRt +RjdqoP+owIaC1aktBN7zl4uO7cRjlYFdusINrh2kPP02KAx2W84xjxX1uyj6oS6e +7eBfvcwe8czW/N1rbE0CoR7h9+HnIrjnVG9RhBiZEiw3mUmF++Up26+4KTdRKbu3 ++BL4yMpfd66z0+zzqu+HkvyLpFn5AgMBAAGjggL/MIIC+zAJBgNVHRMEAjAAMAsG +A1UdDwQEAwIDqDATBgNVHSUEDDAKBggrBgEFBQcDATAdBgNVHQ4EFgQUy04I5guM +drzfh2JQaXhgV86+4jUwHwYDVR0jBBgwFoAU60I00Jiwq5/0G2sI98xkLu8OLEUw +LQYDVR0RBCYwJIISd3d3LmRuc3NlYy1leHAub3Jngg5kbnNzZWMtZXhwLm9yZzCC +AUIGA1UdIASCATkwggE1MIIBMQYLKwYBBAGBtTcBAgIwggEgMC4GCCsGAQUFBwIB +FiJodHRwOi8vd3d3LnN0YXJ0c3NsLmNvbS9wb2xpY3kucGRmMDQGCCsGAQUFBwIB +FihodHRwOi8vd3d3LnN0YXJ0c3NsLmNvbS9pbnRlcm1lZGlhdGUucGRmMIG3Bggr +BgEFBQcCAjCBqjAUFg1TdGFydENvbSBMdGQuMAMCAQEagZFMaW1pdGVkIExpYWJp +bGl0eSwgc2VlIHNlY3Rpb24gKkxlZ2FsIExpbWl0YXRpb25zKiBvZiB0aGUgU3Rh +cnRDb20gQ2VydGlmaWNhdGlvbiBBdXRob3JpdHkgUG9saWN5IGF2YWlsYWJsZSBh +dCBodHRwOi8vd3d3LnN0YXJ0c3NsLmNvbS9wb2xpY3kucGRmMGEGA1UdHwRaMFgw +KqAooCaGJGh0dHA6Ly93d3cuc3RhcnRzc2wuY29tL2NydDEtY3JsLmNybDAqoCig +JoYkaHR0cDovL2NybC5zdGFydHNzbC5jb20vY3J0MS1jcmwuY3JsMIGOBggrBgEF +BQcBAQSBgTB/MDkGCCsGAQUFBzABhi1odHRwOi8vb2NzcC5zdGFydHNzbC5jb20v +c3ViL2NsYXNzMS9zZXJ2ZXIvY2EwQgYIKwYBBQUHMAKGNmh0dHA6Ly93d3cuc3Rh +cnRzc2wuY29tL2NlcnRzL3N1Yi5jbGFzczEuc2VydmVyLmNhLmNydDAjBgNVHRIE +HDAahhhodHRwOi8vd3d3LnN0YXJ0c3NsLmNvbS8wDQYJKoZIhvcNAQEFBQADggEB +ACXj6SB59KRJPenn6gUdGEqcta97U769SATyiQ87i9er64qLwvIGLMa3o2Rcgl2Y +kghUeyLdN/EXyFBYA8L8uvZREPoc7EZukpT/ZDLXy9i2S0jkOxvF2fD/XLbcjGjM +iEYG1/6ASw0ri9C0k4oDDoJLCoeH9++yqF7SFCCMcDkJqiAGXNb4euDpa8vCCtEQ +CSS+ObZbfkreRt3cNCf5LfCXe9OsTnCfc8Cuq81c0oLaG+SmaLUQNBuToq8e9/Zm ++b+/a3RVjxmkV5OCcGVBxsXNDn54Q6wsdw0TBMcjwoEndzpLS7yWgFbbkq5ZiGpw +Qibb2+CfKuQ+WFV1GkVQmVA= +-----END CERTIFICATE-----` + +const startComIntermediate = `-----BEGIN CERTIFICATE----- +MIIGNDCCBBygAwIBAgIBGDANBgkqhkiG9w0BAQUFADB9MQswCQYDVQQGEwJJTDEW +MBQGA1UEChMNU3RhcnRDb20gTHRkLjErMCkGA1UECxMiU2VjdXJlIERpZ2l0YWwg +Q2VydGlmaWNhdGUgU2lnbmluZzEpMCcGA1UEAxMgU3RhcnRDb20gQ2VydGlmaWNh +dGlvbiBBdXRob3JpdHkwHhcNMDcxMDI0MjA1NDE3WhcNMTcxMDI0MjA1NDE3WjCB +jDELMAkGA1UEBhMCSUwxFjAUBgNVBAoTDVN0YXJ0Q29tIEx0ZC4xKzApBgNVBAsT +IlNlY3VyZSBEaWdpdGFsIENlcnRpZmljYXRlIFNpZ25pbmcxODA2BgNVBAMTL1N0 +YXJ0Q29tIENsYXNzIDEgUHJpbWFyeSBJbnRlcm1lZGlhdGUgU2VydmVyIENBMIIB +IjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAtonGrO8JUngHrJJj0PREGBiE +gFYfka7hh/oyULTTRwbw5gdfcA4Q9x3AzhA2NIVaD5Ksg8asWFI/ujjo/OenJOJA +pgh2wJJuniptTT9uYSAK21ne0n1jsz5G/vohURjXzTCm7QduO3CHtPn66+6CPAVv +kvek3AowHpNz/gfK11+AnSJYUq4G2ouHI2mw5CrY6oPSvfNx23BaKA+vWjhwRRI/ +ME3NO68X5Q/LoKldSKqxYVDLNM08XMML6BDAjJvwAwNi/rJsPnIO7hxDKslIDlc5 +xDEhyBDBLIf+VJVSH1I8MRKbf+fAoKVZ1eKPPvDVqOHXcDGpxLPPr21TLwb0pwID +AQABo4IBrTCCAakwDwYDVR0TAQH/BAUwAwEB/zAOBgNVHQ8BAf8EBAMCAQYwHQYD +VR0OBBYEFOtCNNCYsKuf9BtrCPfMZC7vDixFMB8GA1UdIwQYMBaAFE4L7xqkQFul +F2mHMMo0aEPQQa7yMGYGCCsGAQUFBwEBBFowWDAnBggrBgEFBQcwAYYbaHR0cDov +L29jc3Auc3RhcnRzc2wuY29tL2NhMC0GCCsGAQUFBzAChiFodHRwOi8vd3d3LnN0 +YXJ0c3NsLmNvbS9zZnNjYS5jcnQwWwYDVR0fBFQwUjAnoCWgI4YhaHR0cDovL3d3 +dy5zdGFydHNzbC5jb20vc2ZzY2EuY3JsMCegJaAjhiFodHRwOi8vY3JsLnN0YXJ0 +c3NsLmNvbS9zZnNjYS5jcmwwgYAGA1UdIAR5MHcwdQYLKwYBBAGBtTcBAgEwZjAu +BggrBgEFBQcCARYiaHR0cDovL3d3dy5zdGFydHNzbC5jb20vcG9saWN5LnBkZjA0 +BggrBgEFBQcCARYoaHR0cDovL3d3dy5zdGFydHNzbC5jb20vaW50ZXJtZWRpYXRl +LnBkZjANBgkqhkiG9w0BAQUFAAOCAgEAIQlJPqWIbuALi0jaMU2P91ZXouHTYlfp +tVbzhUV1O+VQHwSL5qBaPucAroXQ+/8gA2TLrQLhxpFy+KNN1t7ozD+hiqLjfDen +xk+PNdb01m4Ge90h2c9W/8swIkn+iQTzheWq8ecf6HWQTd35RvdCNPdFWAwRDYSw +xtpdPvkBnufh2lWVvnQce/xNFE+sflVHfXv0pQ1JHpXo9xLBzP92piVH0PN1Nb6X +t1gW66pceG/sUzCv6gRNzKkC4/C2BBL2MLERPZBOVmTX3DxDX3M570uvh+v2/miI +RHLq0gfGabDBoYvvF0nXYbFFSF87ICHpW7LM9NfpMfULFWE7epTj69m8f5SuauNi +YpaoZHy4h/OZMn6SolK+u/hlz8nyMPyLwcKmltdfieFcNID1j0cHL7SRv7Gifl9L +WtBbnySGBVFaaQNlQ0lxxeBvlDRr9hvYqbBMflPrj0jfyjO1SPo2ShpTpjMM0InN +SRXNiTE8kMBy12VLUjWKRhFEuT2OKGWmPnmeXAhEKa2wNREuIU640ucQPl2Eg7PD +wuTSxv0JS3QJ3fGz0xk+gA2iCxnwOOfFwq/iI9th4p1cbiCJSS4jarJiwUW0n6+L +p/EiO/h94pDQehn7Skzj0n1fSoMD7SfWI55rjbRZotnvbIIp3XUZPD9MEI3vu3Un +0q6Dp6jOW6c= +-----END CERTIFICATE-----` + +const startComRoot = `-----BEGIN CERTIFICATE----- +MIIHyTCCBbGgAwIBAgIBATANBgkqhkiG9w0BAQUFADB9MQswCQYDVQQGEwJJTDEW +MBQGA1UEChMNU3RhcnRDb20gTHRkLjErMCkGA1UECxMiU2VjdXJlIERpZ2l0YWwg +Q2VydGlmaWNhdGUgU2lnbmluZzEpMCcGA1UEAxMgU3RhcnRDb20gQ2VydGlmaWNh +dGlvbiBBdXRob3JpdHkwHhcNMDYwOTE3MTk0NjM2WhcNMzYwOTE3MTk0NjM2WjB9 +MQswCQYDVQQGEwJJTDEWMBQGA1UEChMNU3RhcnRDb20gTHRkLjErMCkGA1UECxMi +U2VjdXJlIERpZ2l0YWwgQ2VydGlmaWNhdGUgU2lnbmluZzEpMCcGA1UEAxMgU3Rh +cnRDb20gQ2VydGlmaWNhdGlvbiBBdXRob3JpdHkwggIiMA0GCSqGSIb3DQEBAQUA +A4ICDwAwggIKAoICAQDBiNsJvGxGfHiflXu1M5DycmLWwTYgIiRezul38kMKogZk +pMyONvg45iPwbm2xPN1yo4UcodM9tDMr0y+v/uqwQVlntsQGfQqedIXWeUyAN3rf +OQVSWff0G0ZDpNKFhdLDcfN1YjS6LIp/Ho/u7TTQEceWzVI9ujPW3U3eCztKS5/C +Ji/6tRYccjV3yjxd5srhJosaNnZcAdt0FCX+7bWgiA/deMotHweXMAEtcnn6RtYT +Kqi5pquDSR3l8u/d5AGOGAqPY1MWhWKpDhk6zLVmpsJrdAfkK+F2PrRt2PZE4XNi +HzvEvqBTViVsUQn3qqvKv3b9bZvzndu/PWa8DFaqr5hIlTpL36dYUNk4dalb6kMM +Av+Z6+hsTXBbKWWc3apdzK8BMewM69KN6Oqce+Zu9ydmDBpI125C4z/eIT574Q1w ++2OqqGwaVLRcJXrJosmLFqa7LH4XXgVNWG4SHQHuEhANxjJ/GP/89PrNbpHoNkm+ +Gkhpi8KWTRoSsmkXwQqQ1vp5Iki/untp+HDH+no32NgN0nZPV/+Qt+OR0t3vwmC3 +Zzrd/qqc8NSLf3Iizsafl7b4r4qgEKjZ+xjGtrVcUjyJthkqcwEKDwOzEmDyei+B +26Nu/yYwl/WL3YlXtq09s68rxbd2AvCl1iuahhQqcvbjM4xdCUsT37uMdBNSSwID +AQABo4ICUjCCAk4wDAYDVR0TBAUwAwEB/zALBgNVHQ8EBAMCAa4wHQYDVR0OBBYE +FE4L7xqkQFulF2mHMMo0aEPQQa7yMGQGA1UdHwRdMFswLKAqoCiGJmh0dHA6Ly9j +ZXJ0LnN0YXJ0Y29tLm9yZy9zZnNjYS1jcmwuY3JsMCugKaAnhiVodHRwOi8vY3Js +LnN0YXJ0Y29tLm9yZy9zZnNjYS1jcmwuY3JsMIIBXQYDVR0gBIIBVDCCAVAwggFM +BgsrBgEEAYG1NwEBATCCATswLwYIKwYBBQUHAgEWI2h0dHA6Ly9jZXJ0LnN0YXJ0 +Y29tLm9yZy9wb2xpY3kucGRmMDUGCCsGAQUFBwIBFilodHRwOi8vY2VydC5zdGFy +dGNvbS5vcmcvaW50ZXJtZWRpYXRlLnBkZjCB0AYIKwYBBQUHAgIwgcMwJxYgU3Rh +cnQgQ29tbWVyY2lhbCAoU3RhcnRDb20pIEx0ZC4wAwIBARqBl0xpbWl0ZWQgTGlh +YmlsaXR5LCByZWFkIHRoZSBzZWN0aW9uICpMZWdhbCBMaW1pdGF0aW9ucyogb2Yg +dGhlIFN0YXJ0Q29tIENlcnRpZmljYXRpb24gQXV0aG9yaXR5IFBvbGljeSBhdmFp +bGFibGUgYXQgaHR0cDovL2NlcnQuc3RhcnRjb20ub3JnL3BvbGljeS5wZGYwEQYJ +YIZIAYb4QgEBBAQDAgAHMDgGCWCGSAGG+EIBDQQrFilTdGFydENvbSBGcmVlIFNT +TCBDZXJ0aWZpY2F0aW9uIEF1dGhvcml0eTANBgkqhkiG9w0BAQUFAAOCAgEAFmyZ +9GYMNPXQhV59CuzaEE44HF7fpiUFS5Eyweg78T3dRAlbB0mKKctmArexmvclmAk8 +jhvh3TaHK0u7aNM5Zj2gJsfyOZEdUauCe37Vzlrk4gNXcGmXCPleWKYK34wGmkUW +FjgKXlf2Ysd6AgXmvB618p70qSmD+LIU424oh0TDkBreOKk8rENNZEXO3SipXPJz +ewT4F+irsfMuXGRuczE6Eri8sxHkfY+BUZo7jYn0TZNmezwD7dOaHZrzZVD1oNB1 +ny+v8OqCQ5j4aZyJecRDjkZy42Q2Eq/3JR44iZB3fsNrarnDy0RLrHiQi+fHLB5L +EUTINFInzQpdn4XBidUaePKVEFMy3YCEZnXZtWgo+2EuvoSoOMCZEoalHmdkrQYu +L6lwhceWD3yJZfWOQ1QOq92lgDmUYMA0yZZwLKMS9R9Ie70cfmu3nZD0Ijuu+Pwq +yvqCUqDvr0tVk+vBtfAii6w0TiYiBKGHLHVKt+V9E9e4DGTANtLJL4YSjCMJwRuC +O3NJo2pXh5Tl1njFmUNj403gdy3hZZlyaQQaRwnmDwFWJPsfvw55qVguucQJAX6V +um0ABj6y6koQOdjQK/W/7HW/lwLFCRsI3FU34oH7N4RDYiDK51ZLZer+bMEkkySh +NOsF/5oirpt9P/FlUQqmMGqz9IgcgA38corog14= +-----END CERTIFICATE-----` diff --git a/src/pkg/crypto/x509/x509.go b/src/pkg/crypto/x509/x509.go index 2a57f8758..f2a039b5a 100644 --- a/src/pkg/crypto/x509/x509.go +++ b/src/pkg/crypto/x509/x509.go @@ -2,12 +2,13 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package parses X.509-encoded keys and certificates. +// Package x509 parses X.509-encoded keys and certificates. package x509 import ( "asn1" "big" + "bytes" "container/vector" "crypto" "crypto/rsa" @@ -15,7 +16,6 @@ import ( "hash" "io" "os" - "strings" "time" ) @@ -27,6 +27,20 @@ type pkcs1PrivateKey struct { D asn1.RawValue P asn1.RawValue Q asn1.RawValue + // We ignore these values, if present, because rsa will calculate them. + Dp asn1.RawValue "optional" + Dq asn1.RawValue "optional" + Qinv asn1.RawValue "optional" + + AdditionalPrimes []pkcs1AddtionalRSAPrime "optional" +} + +type pkcs1AddtionalRSAPrime struct { + Prime asn1.RawValue + + // We ignore these values because rsa will calculate them. + Exp asn1.RawValue + Coeff asn1.RawValue } // rawValueIsInteger returns true iff the given ASN.1 RawValue is an INTEGER type. @@ -46,6 +60,10 @@ func ParsePKCS1PrivateKey(der []byte) (key *rsa.PrivateKey, err os.Error) { return } + if priv.Version > 1 { + return nil, os.ErrorString("x509: unsupported private key version") + } + if !rawValueIsInteger(&priv.N) || !rawValueIsInteger(&priv.D) || !rawValueIsInteger(&priv.P) || @@ -61,26 +79,66 @@ func ParsePKCS1PrivateKey(der []byte) (key *rsa.PrivateKey, err os.Error) { } key.D = new(big.Int).SetBytes(priv.D.Bytes) - key.P = new(big.Int).SetBytes(priv.P.Bytes) - key.Q = new(big.Int).SetBytes(priv.Q.Bytes) + key.Primes = make([]*big.Int, 2+len(priv.AdditionalPrimes)) + key.Primes[0] = new(big.Int).SetBytes(priv.P.Bytes) + key.Primes[1] = new(big.Int).SetBytes(priv.Q.Bytes) + for i, a := range priv.AdditionalPrimes { + if !rawValueIsInteger(&a.Prime) { + return nil, asn1.StructuralError{"tags don't match"} + } + key.Primes[i+2] = new(big.Int).SetBytes(a.Prime.Bytes) + // We ignore the other two values because rsa will calculate + // them as needed. + } err = key.Validate() if err != nil { return nil, err } + key.Precompute() return } +// rawValueForBig returns an asn1.RawValue which represents the given integer. +func rawValueForBig(n *big.Int) asn1.RawValue { + b := n.Bytes() + if n.Sign() >= 0 && len(b) > 0 && b[0]&0x80 != 0 { + // This positive number would be interpreted as a negative + // number in ASN.1 because the MSB is set. + padded := make([]byte, len(b)+1) + copy(padded[1:], b) + b = padded + } + return asn1.RawValue{Tag: 2, Bytes: b} +} + // MarshalPKCS1PrivateKey converts a private key to ASN.1 DER encoded form. func MarshalPKCS1PrivateKey(key *rsa.PrivateKey) []byte { + key.Precompute() + + version := 0 + if len(key.Primes) > 2 { + version = 1 + } + priv := pkcs1PrivateKey{ - Version: 1, - N: asn1.RawValue{Tag: 2, Bytes: key.PublicKey.N.Bytes()}, + Version: version, + N: rawValueForBig(key.N), E: key.PublicKey.E, - D: asn1.RawValue{Tag: 2, Bytes: key.D.Bytes()}, - P: asn1.RawValue{Tag: 2, Bytes: key.P.Bytes()}, - Q: asn1.RawValue{Tag: 2, Bytes: key.Q.Bytes()}, + D: rawValueForBig(key.D), + P: rawValueForBig(key.Primes[0]), + Q: rawValueForBig(key.Primes[1]), + Dp: rawValueForBig(key.Precomputed.Dp), + Dq: rawValueForBig(key.Precomputed.Dq), + Qinv: rawValueForBig(key.Precomputed.Qinv), + } + + priv.AdditionalPrimes = make([]pkcs1AddtionalRSAPrime, len(key.Precomputed.CRTValues)) + for i, values := range key.Precomputed.CRTValues { + priv.AdditionalPrimes[i].Prime = rawValueForBig(key.Primes[2+i]) + priv.AdditionalPrimes[i].Exp = rawValueForBig(values.Exp) + priv.AdditionalPrimes[i].Coeff = rawValueForBig(values.Coeff) } b, _ := asn1.Marshal(priv) @@ -397,6 +455,10 @@ func (ConstraintViolationError) String() string { return "invalid signature: parent certificate cannot sign this kind of certificate" } +func (c *Certificate) Equal(other *Certificate) bool { + return bytes.Equal(c.Raw, other.Raw) +} + // CheckSignatureFrom verifies that the signature on c is a valid signature // from parent. func (c *Certificate) CheckSignatureFrom(parent *Certificate) (err os.Error) { @@ -442,63 +504,6 @@ func (c *Certificate) CheckSignatureFrom(parent *Certificate) (err os.Error) { return rsa.VerifyPKCS1v15(pub, hashType, digest, c.Signature) } -func matchHostnames(pattern, host string) bool { - if len(pattern) == 0 || len(host) == 0 { - return false - } - - patternParts := strings.Split(pattern, ".", -1) - hostParts := strings.Split(host, ".", -1) - - if len(patternParts) != len(hostParts) { - return false - } - - for i, patternPart := range patternParts { - if patternPart == "*" { - continue - } - if patternPart != hostParts[i] { - return false - } - } - - return true -} - -type HostnameError struct { - Certificate *Certificate - Host string -} - -func (h *HostnameError) String() string { - var valid string - c := h.Certificate - if len(c.DNSNames) > 0 { - valid = strings.Join(c.DNSNames, ", ") - } else { - valid = c.Subject.CommonName - } - return "certificate is valid for " + valid + ", not " + h.Host -} - -// VerifyHostname returns nil if c is a valid certificate for the named host. -// Otherwise it returns an os.Error describing the mismatch. -func (c *Certificate) VerifyHostname(h string) os.Error { - if len(c.DNSNames) > 0 { - for _, match := range c.DNSNames { - if matchHostnames(match, h) { - return nil - } - } - // If Subject Alt Name is given, we ignore the common name. - } else if matchHostnames(c.Subject.CommonName, h) { - return nil - } - - return &HostnameError{c, h} -} - type UnhandledCriticalExtension struct{} func (h UnhandledCriticalExtension) String() string { diff --git a/src/pkg/crypto/x509/x509_test.go b/src/pkg/crypto/x509/x509_test.go index d9511b863..a42113add 100644 --- a/src/pkg/crypto/x509/x509_test.go +++ b/src/pkg/crypto/x509/x509_test.go @@ -20,12 +20,13 @@ func TestParsePKCS1PrivateKey(t *testing.T) { priv, err := ParsePKCS1PrivateKey(block.Bytes) if err != nil { t.Errorf("Failed to parse private key: %s", err) + return } if priv.PublicKey.N.Cmp(rsaPrivateKey.PublicKey.N) != 0 || priv.PublicKey.E != rsaPrivateKey.PublicKey.E || priv.D.Cmp(rsaPrivateKey.D) != 0 || - priv.P.Cmp(rsaPrivateKey.P) != 0 || - priv.Q.Cmp(rsaPrivateKey.Q) != 0 { + priv.Primes[0].Cmp(rsaPrivateKey.Primes[0]) != 0 || + priv.Primes[1].Cmp(rsaPrivateKey.Primes[1]) != 0 { t.Errorf("got:%+v want:%+v", priv, rsaPrivateKey) } } @@ -47,14 +48,54 @@ func bigFromString(s string) *big.Int { return ret } +func fromBase10(base10 string) *big.Int { + i := new(big.Int) + i.SetString(base10, 10) + return i +} + var rsaPrivateKey = &rsa.PrivateKey{ PublicKey: rsa.PublicKey{ N: bigFromString("9353930466774385905609975137998169297361893554149986716853295022578535724979677252958524466350471210367835187480748268864277464700638583474144061408845077"), E: 65537, }, D: bigFromString("7266398431328116344057699379749222532279343923819063639497049039389899328538543087657733766554155839834519529439851673014800261285757759040931985506583861"), - P: bigFromString("98920366548084643601728869055592650835572950932266967461790948584315647051443"), - Q: bigFromString("94560208308847015747498523884063394671606671904944666360068158221458669711639"), + Primes: []*big.Int{ + bigFromString("98920366548084643601728869055592650835572950932266967461790948584315647051443"), + bigFromString("94560208308847015747498523884063394671606671904944666360068158221458669711639"), + }, +} + +func TestMarshalRSAPrivateKey(t *testing.T) { + priv := &rsa.PrivateKey{ + PublicKey: rsa.PublicKey{ + N: fromBase10("16346378922382193400538269749936049106320265317511766357599732575277382844051791096569333808598921852351577762718529818072849191122419410612033592401403764925096136759934497687765453905884149505175426053037420486697072448609022753683683718057795566811401938833367954642951433473337066311978821180526439641496973296037000052546108507805269279414789035461158073156772151892452251106173507240488993608650881929629163465099476849643165682709047462010581308719577053905787496296934240246311806555924593059995202856826239801816771116902778517096212527979497399966526283516447337775509777558018145573127308919204297111496233"), + E: 3, + }, + D: fromBase10("10897585948254795600358846499957366070880176878341177571733155050184921896034527397712889205732614568234385175145686545381899460748279607074689061600935843283397424506622998458510302603922766336783617368686090042765718290914099334449154829375179958369993407724946186243249568928237086215759259909861748642124071874879861299389874230489928271621259294894142840428407196932444474088857746123104978617098858619445675532587787023228852383149557470077802718705420275739737958953794088728369933811184572620857678792001136676902250566845618813972833750098806496641114644760255910789397593428910198080271317419213080834885003"), + Primes: []*big.Int{ + fromBase10("1025363189502892836833747188838978207017355117492483312747347695538428729137306368764177201532277413433182799108299960196606011786562992097313508180436744488171474690412562218914213688661311117337381958560443"), + fromBase10("3467903426626310123395340254094941045497208049900750380025518552334536945536837294961497712862519984786362199788654739924501424784631315081391467293694361474867825728031147665777546570788493758372218019373"), + fromBase10("4597024781409332673052708605078359346966325141767460991205742124888960305710298765592730135879076084498363772408626791576005136245060321874472727132746643162385746062759369754202494417496879741537284589047"), + }, + } + + derBytes := MarshalPKCS1PrivateKey(priv) + + priv2, err := ParsePKCS1PrivateKey(derBytes) + if err != nil { + t.Errorf("error parsing serialized key: %s", err) + return + } + if priv.PublicKey.N.Cmp(priv2.PublicKey.N) != 0 || + priv.PublicKey.E != priv2.PublicKey.E || + priv.D.Cmp(priv2.D) != 0 || + len(priv2.Primes) != 3 || + priv.Primes[0].Cmp(priv2.Primes[0]) != 0 || + priv.Primes[1].Cmp(priv2.Primes[1]) != 0 || + priv.Primes[2].Cmp(priv2.Primes[2]) != 0 { + t.Errorf("got:%+v want:%+v", priv, priv2) + } } type matchHostnamesTest struct { diff --git a/src/pkg/crypto/xtea/cipher.go b/src/pkg/crypto/xtea/cipher.go index b0fa2a184..f2a5da003 100644 --- a/src/pkg/crypto/xtea/cipher.go +++ b/src/pkg/crypto/xtea/cipher.go @@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements XTEA encryption, as defined in Needham and -// Wheeler's 1997 technical report, "Tea extensions." +// Package xtea implements XTEA encryption, as defined in Needham and Wheeler's +// 1997 technical report, "Tea extensions." package xtea // For details, see http://www.cix.co.uk/~klockstone/xtea.pdf diff --git a/src/pkg/debug/dwarf/open.go b/src/pkg/debug/dwarf/open.go index cb009e0e0..d9525f788 100644 --- a/src/pkg/debug/dwarf/open.go +++ b/src/pkg/debug/dwarf/open.go @@ -2,9 +2,9 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package provides access to DWARF debugging information -// loaded from executable files, as defined in the DWARF 2.0 Standard -// at http://dwarfstd.org/doc/dwarf-2.0.0.pdf +// Package dwarf provides access to DWARF debugging information loaded from +// executable files, as defined in the DWARF 2.0 Standard at +// http://dwarfstd.org/doc/dwarf-2.0.0.pdf package dwarf import ( diff --git a/src/pkg/debug/elf/elf.go b/src/pkg/debug/elf/elf.go index 74e979986..5d45b2486 100644 --- a/src/pkg/debug/elf/elf.go +++ b/src/pkg/debug/elf/elf.go @@ -330,29 +330,35 @@ func (i SectionIndex) GoString() string { return stringName(uint32(i), shnString type SectionType uint32 const ( - SHT_NULL SectionType = 0 /* inactive */ - SHT_PROGBITS SectionType = 1 /* program defined information */ - SHT_SYMTAB SectionType = 2 /* symbol table section */ - SHT_STRTAB SectionType = 3 /* string table section */ - SHT_RELA SectionType = 4 /* relocation section with addends */ - SHT_HASH SectionType = 5 /* symbol hash table section */ - SHT_DYNAMIC SectionType = 6 /* dynamic section */ - SHT_NOTE SectionType = 7 /* note section */ - SHT_NOBITS SectionType = 8 /* no space section */ - SHT_REL SectionType = 9 /* relocation section - no addends */ - SHT_SHLIB SectionType = 10 /* reserved - purpose unknown */ - SHT_DYNSYM SectionType = 11 /* dynamic symbol table section */ - SHT_INIT_ARRAY SectionType = 14 /* Initialization function pointers. */ - SHT_FINI_ARRAY SectionType = 15 /* Termination function pointers. */ - SHT_PREINIT_ARRAY SectionType = 16 /* Pre-initialization function ptrs. */ - SHT_GROUP SectionType = 17 /* Section group. */ - SHT_SYMTAB_SHNDX SectionType = 18 /* Section indexes (see SHN_XINDEX). */ - SHT_LOOS SectionType = 0x60000000 /* First of OS specific semantics */ - SHT_HIOS SectionType = 0x6fffffff /* Last of OS specific semantics */ - SHT_LOPROC SectionType = 0x70000000 /* reserved range for processor */ - SHT_HIPROC SectionType = 0x7fffffff /* specific section header types */ - SHT_LOUSER SectionType = 0x80000000 /* reserved range for application */ - SHT_HIUSER SectionType = 0xffffffff /* specific indexes */ + SHT_NULL SectionType = 0 /* inactive */ + SHT_PROGBITS SectionType = 1 /* program defined information */ + SHT_SYMTAB SectionType = 2 /* symbol table section */ + SHT_STRTAB SectionType = 3 /* string table section */ + SHT_RELA SectionType = 4 /* relocation section with addends */ + SHT_HASH SectionType = 5 /* symbol hash table section */ + SHT_DYNAMIC SectionType = 6 /* dynamic section */ + SHT_NOTE SectionType = 7 /* note section */ + SHT_NOBITS SectionType = 8 /* no space section */ + SHT_REL SectionType = 9 /* relocation section - no addends */ + SHT_SHLIB SectionType = 10 /* reserved - purpose unknown */ + SHT_DYNSYM SectionType = 11 /* dynamic symbol table section */ + SHT_INIT_ARRAY SectionType = 14 /* Initialization function pointers. */ + SHT_FINI_ARRAY SectionType = 15 /* Termination function pointers. */ + SHT_PREINIT_ARRAY SectionType = 16 /* Pre-initialization function ptrs. */ + SHT_GROUP SectionType = 17 /* Section group. */ + SHT_SYMTAB_SHNDX SectionType = 18 /* Section indexes (see SHN_XINDEX). */ + SHT_LOOS SectionType = 0x60000000 /* First of OS specific semantics */ + SHT_GNU_ATTRIBUTES SectionType = 0x6ffffff5 /* GNU object attributes */ + SHT_GNU_HASH SectionType = 0x6ffffff6 /* GNU hash table */ + SHT_GNU_LIBLIST SectionType = 0x6ffffff7 /* GNU prelink library list */ + SHT_GNU_VERDEF SectionType = 0x6ffffffd /* GNU version definition section */ + SHT_GNU_VERNEED SectionType = 0x6ffffffe /* GNU version needs section */ + SHT_GNU_VERSYM SectionType = 0x6fffffff /* GNU version symbol table */ + SHT_HIOS SectionType = 0x6fffffff /* Last of OS specific semantics */ + SHT_LOPROC SectionType = 0x70000000 /* reserved range for processor */ + SHT_HIPROC SectionType = 0x7fffffff /* specific section header types */ + SHT_LOUSER SectionType = 0x80000000 /* reserved range for application */ + SHT_HIUSER SectionType = 0xffffffff /* specific indexes */ ) var shtStrings = []intName{ @@ -374,7 +380,12 @@ var shtStrings = []intName{ {17, "SHT_GROUP"}, {18, "SHT_SYMTAB_SHNDX"}, {0x60000000, "SHT_LOOS"}, - {0x6fffffff, "SHT_HIOS"}, + {0x6ffffff5, "SHT_GNU_ATTRIBUTES"}, + {0x6ffffff6, "SHT_GNU_HASH"}, + {0x6ffffff7, "SHT_GNU_LIBLIST"}, + {0x6ffffffd, "SHT_GNU_VERDEF"}, + {0x6ffffffe, "SHT_GNU_VERNEED"}, + {0x6fffffff, "SHT_GNU_VERSYM"}, {0x70000000, "SHT_LOPROC"}, {0x7fffffff, "SHT_HIPROC"}, {0x80000000, "SHT_LOUSER"}, @@ -518,6 +529,9 @@ const ( DT_PREINIT_ARRAYSZ DynTag = 33 /* Size in bytes of the array of pre-initialization functions. */ DT_LOOS DynTag = 0x6000000d /* First OS-specific */ DT_HIOS DynTag = 0x6ffff000 /* Last OS-specific */ + DT_VERSYM DynTag = 0x6ffffff0 + DT_VERNEED DynTag = 0x6ffffffe + DT_VERNEEDNUM DynTag = 0x6fffffff DT_LOPROC DynTag = 0x70000000 /* First processor-specific type. */ DT_HIPROC DynTag = 0x7fffffff /* Last processor-specific type. */ ) @@ -559,6 +573,9 @@ var dtStrings = []intName{ {33, "DT_PREINIT_ARRAYSZ"}, {0x6000000d, "DT_LOOS"}, {0x6ffff000, "DT_HIOS"}, + {0x6ffffff0, "DT_VERSYM"}, + {0x6ffffffe, "DT_VERNEED"}, + {0x6fffffff, "DT_VERNEEDNUM"}, {0x70000000, "DT_LOPROC"}, {0x7fffffff, "DT_HIPROC"}, } diff --git a/src/pkg/debug/elf/file.go b/src/pkg/debug/elf/file.go index 6fdcda6d4..9ae8b413d 100644 --- a/src/pkg/debug/elf/file.go +++ b/src/pkg/debug/elf/file.go @@ -35,9 +35,11 @@ type FileHeader struct { // A File represents an open ELF file. type File struct { FileHeader - Sections []*Section - Progs []*Prog - closer io.Closer + Sections []*Section + Progs []*Prog + closer io.Closer + gnuNeed []verneed + gnuVersym []byte } // A SectionHeader represents a single ELF section header. @@ -329,8 +331,8 @@ func NewFile(r io.ReaderAt) (*File, os.Error) { } // getSymbols returns a slice of Symbols from parsing the symbol table -// with the given type. -func (f *File) getSymbols(typ SectionType) ([]Symbol, os.Error) { +// with the given type, along with the associated string table. +func (f *File) getSymbols(typ SectionType) ([]Symbol, []byte, os.Error) { switch f.Class { case ELFCLASS64: return f.getSymbols64(typ) @@ -339,27 +341,27 @@ func (f *File) getSymbols(typ SectionType) ([]Symbol, os.Error) { return f.getSymbols32(typ) } - return nil, os.ErrorString("not implemented") + return nil, nil, os.ErrorString("not implemented") } -func (f *File) getSymbols32(typ SectionType) ([]Symbol, os.Error) { +func (f *File) getSymbols32(typ SectionType) ([]Symbol, []byte, os.Error) { symtabSection := f.SectionByType(typ) if symtabSection == nil { - return nil, os.ErrorString("no symbol section") + return nil, nil, os.ErrorString("no symbol section") } data, err := symtabSection.Data() if err != nil { - return nil, os.ErrorString("cannot load symbol section") + return nil, nil, os.ErrorString("cannot load symbol section") } symtab := bytes.NewBuffer(data) if symtab.Len()%Sym32Size != 0 { - return nil, os.ErrorString("length of symbol section is not a multiple of SymSize") + return nil, nil, os.ErrorString("length of symbol section is not a multiple of SymSize") } strdata, err := f.stringTable(symtabSection.Link) if err != nil { - return nil, os.ErrorString("cannot load string table section") + return nil, nil, os.ErrorString("cannot load string table section") } // The first entry is all zeros. @@ -382,27 +384,27 @@ func (f *File) getSymbols32(typ SectionType) ([]Symbol, os.Error) { i++ } - return symbols, nil + return symbols, strdata, nil } -func (f *File) getSymbols64(typ SectionType) ([]Symbol, os.Error) { +func (f *File) getSymbols64(typ SectionType) ([]Symbol, []byte, os.Error) { symtabSection := f.SectionByType(typ) if symtabSection == nil { - return nil, os.ErrorString("no symbol section") + return nil, nil, os.ErrorString("no symbol section") } data, err := symtabSection.Data() if err != nil { - return nil, os.ErrorString("cannot load symbol section") + return nil, nil, os.ErrorString("cannot load symbol section") } symtab := bytes.NewBuffer(data) if symtab.Len()%Sym64Size != 0 { - return nil, os.ErrorString("length of symbol section is not a multiple of Sym64Size") + return nil, nil, os.ErrorString("length of symbol section is not a multiple of Sym64Size") } strdata, err := f.stringTable(symtabSection.Link) if err != nil { - return nil, os.ErrorString("cannot load string table section") + return nil, nil, os.ErrorString("cannot load string table section") } // The first entry is all zeros. @@ -425,7 +427,7 @@ func (f *File) getSymbols64(typ SectionType) ([]Symbol, os.Error) { i++ } - return symbols, nil + return symbols, strdata, nil } // getString extracts a string from an ELF string table. @@ -468,7 +470,7 @@ func (f *File) applyRelocationsAMD64(dst []byte, rels []byte) os.Error { return os.ErrorString("length of relocation section is not a multiple of Sym64Size") } - symbols, err := f.getSymbols(SHT_SYMTAB) + symbols, _, err := f.getSymbols(SHT_SYMTAB) if err != nil { return err } @@ -544,24 +546,123 @@ func (f *File) DWARF() (*dwarf.Data, os.Error) { return dwarf.New(abbrev, nil, nil, info, nil, nil, nil, str) } +type ImportedSymbol struct { + Name string + Version string + Library string +} + // ImportedSymbols returns the names of all symbols // referred to by the binary f that are expected to be // satisfied by other libraries at dynamic load time. // It does not return weak symbols. -func (f *File) ImportedSymbols() ([]string, os.Error) { - sym, err := f.getSymbols(SHT_DYNSYM) +func (f *File) ImportedSymbols() ([]ImportedSymbol, os.Error) { + sym, str, err := f.getSymbols(SHT_DYNSYM) if err != nil { return nil, err } - var all []string - for _, s := range sym { + f.gnuVersionInit(str) + var all []ImportedSymbol + for i, s := range sym { if ST_BIND(s.Info) == STB_GLOBAL && s.Section == SHN_UNDEF { - all = append(all, s.Name) + all = append(all, ImportedSymbol{Name: s.Name}) + f.gnuVersion(i, &all[len(all)-1]) } } return all, nil } +type verneed struct { + File string + Name string +} + +// gnuVersionInit parses the GNU version tables +// for use by calls to gnuVersion. +func (f *File) gnuVersionInit(str []byte) { + // Accumulate verneed information. + vn := f.SectionByType(SHT_GNU_VERNEED) + if vn == nil { + return + } + d, _ := vn.Data() + + var need []verneed + i := 0 + for { + if i+16 > len(d) { + break + } + vers := f.ByteOrder.Uint16(d[i : i+2]) + if vers != 1 { + break + } + cnt := f.ByteOrder.Uint16(d[i+2 : i+4]) + fileoff := f.ByteOrder.Uint32(d[i+4 : i+8]) + aux := f.ByteOrder.Uint32(d[i+8 : i+12]) + next := f.ByteOrder.Uint32(d[i+12 : i+16]) + file, _ := getString(str, int(fileoff)) + + var name string + j := i + int(aux) + for c := 0; c < int(cnt); c++ { + if j+16 > len(d) { + break + } + // hash := f.ByteOrder.Uint32(d[j:j+4]) + // flags := f.ByteOrder.Uint16(d[j+4:j+6]) + other := f.ByteOrder.Uint16(d[j+6 : j+8]) + nameoff := f.ByteOrder.Uint32(d[j+8 : j+12]) + next := f.ByteOrder.Uint32(d[j+12 : j+16]) + name, _ = getString(str, int(nameoff)) + ndx := int(other) + if ndx >= len(need) { + a := make([]verneed, 2*(ndx+1)) + copy(a, need) + need = a + } + + need[ndx] = verneed{file, name} + if next == 0 { + break + } + j += int(next) + } + + if next == 0 { + break + } + i += int(next) + } + + // Versym parallels symbol table, indexing into verneed. + vs := f.SectionByType(SHT_GNU_VERSYM) + if vs == nil { + return + } + d, _ = vs.Data() + + f.gnuNeed = need + f.gnuVersym = d +} + +// gnuVersion adds Library and Version information to sym, +// which came from offset i of the symbol table. +func (f *File) gnuVersion(i int, sym *ImportedSymbol) { + // Each entry is two bytes; skip undef entry at beginning. + i = (i + 1) * 2 + if i >= len(f.gnuVersym) { + return + } + j := int(f.ByteOrder.Uint16(f.gnuVersym[i:])) + if j < 2 || j >= len(f.gnuNeed) { + return + } + n := &f.gnuNeed[j] + sym.Library = n.File + sym.Version = n.Name +} + // ImportedLibraries returns the names of all libraries // referred to by the binary f that are expected to be // linked with the binary at dynamic link time. diff --git a/src/pkg/ebnf/ebnf.go b/src/pkg/ebnf/ebnf.go index e5aabd582..7918c4593 100644 --- a/src/pkg/ebnf/ebnf.go +++ b/src/pkg/ebnf/ebnf.go @@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// A library for EBNF grammars. The input is text ([]byte) satisfying -// the following grammar (represented itself in EBNF): +// Package ebnf is a library for EBNF grammars. The input is text ([]byte) +// satisfying the following grammar (represented itself in EBNF): // // Production = name "=" Expression "." . // Expression = Alternative { "|" Alternative } . diff --git a/src/pkg/encoding/binary/binary.go b/src/pkg/encoding/binary/binary.go index a4b390701..a01d0e024 100644 --- a/src/pkg/encoding/binary/binary.go +++ b/src/pkg/encoding/binary/binary.go @@ -126,7 +126,7 @@ func (bigEndian) GoString() string { return "binary.BigEndian" } // and written to successive fields of the data. func Read(r io.Reader, order ByteOrder, data interface{}) os.Error { var v reflect.Value - switch d := reflect.NewValue(data); d.Kind() { + switch d := reflect.ValueOf(data); d.Kind() { case reflect.Ptr: v = d.Elem() case reflect.Slice: @@ -155,7 +155,7 @@ func Read(r io.Reader, order ByteOrder, data interface{}) os.Error { // Bytes written to w are encoded using the specified byte order // and read from successive fields of the data. func Write(w io.Writer, order ByteOrder, data interface{}) os.Error { - v := reflect.Indirect(reflect.NewValue(data)) + v := reflect.Indirect(reflect.ValueOf(data)) size := TotalSize(v) if size < 0 { return os.NewError("binary.Write: invalid type " + v.Type().String()) diff --git a/src/pkg/encoding/binary/binary_test.go b/src/pkg/encoding/binary/binary_test.go index d1fc1bfd3..7857c68d3 100644 --- a/src/pkg/encoding/binary/binary_test.go +++ b/src/pkg/encoding/binary/binary_test.go @@ -152,7 +152,7 @@ func TestWriteT(t *testing.T) { t.Errorf("WriteT: have nil, want non-nil") } - tv := reflect.Indirect(reflect.NewValue(ts)) + tv := reflect.Indirect(reflect.ValueOf(ts)) for i, n := 0, tv.NumField(); i < n; i++ { err = Write(buf, BigEndian, tv.Field(i).Interface()) if err == nil { diff --git a/src/pkg/encoding/hex/hex.go b/src/pkg/encoding/hex/hex.go index 292d917eb..891de1861 100644 --- a/src/pkg/encoding/hex/hex.go +++ b/src/pkg/encoding/hex/hex.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements hexadecimal encoding and decoding. +// Package hex implements hexadecimal encoding and decoding. package hex import ( diff --git a/src/pkg/encoding/line/line.go b/src/pkg/encoding/line/line.go index f46ce1c83..123962b1f 100644 --- a/src/pkg/encoding/line/line.go +++ b/src/pkg/encoding/line/line.go @@ -2,7 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The line package implements a Reader that reads lines delimited by '\n' or ' \r\n'. +// Package line implements a Reader that reads lines delimited by '\n' or +// ' \r\n'. package line import ( diff --git a/src/pkg/encoding/pem/pem.go b/src/pkg/encoding/pem/pem.go index 5653aeb77..44e3d0ad0 100644 --- a/src/pkg/encoding/pem/pem.go +++ b/src/pkg/encoding/pem/pem.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements the PEM data encoding, which originated in Privacy +// Package pem implements the PEM data encoding, which originated in Privacy // Enhanced Mail. The most common use of PEM encoding today is in TLS keys and // certificates. See RFC 1421. package pem diff --git a/src/pkg/exec/exec.go b/src/pkg/exec/exec.go index 5398eb8e0..043f84728 100644 --- a/src/pkg/exec/exec.go +++ b/src/pkg/exec/exec.go @@ -2,9 +2,9 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The exec package runs external commands. It wraps os.StartProcess -// to make it easier to remap stdin and stdout, connect I/O with pipes, -// and do other adjustments. +// Package exec runs external commands. It wraps os.StartProcess to make it +// easier to remap stdin and stdout, connect I/O with pipes, and do other +// adjustments. package exec // BUG(r): This package should be made even easier to use or merged into os. diff --git a/src/pkg/exec/exec_test.go b/src/pkg/exec/exec_test.go index 5e37b99ee..eb8cd5fec 100644 --- a/src/pkg/exec/exec_test.go +++ b/src/pkg/exec/exec_test.go @@ -9,19 +9,14 @@ import ( "io/ioutil" "testing" "os" - "runtime" ) func run(argv []string, stdin, stdout, stderr int) (p *Cmd, err os.Error) { - if runtime.GOOS == "windows" { - argv = append([]string{"cmd", "/c"}, argv...) - } exe, err := LookPath(argv[0]) if err != nil { return nil, err } - p, err = Run(exe, argv, nil, "", stdin, stdout, stderr) - return p, err + return Run(exe, argv, nil, "", stdin, stdout, stderr) } func TestRunCat(t *testing.T) { diff --git a/src/pkg/exp/datafmt/datafmt.go b/src/pkg/exp/datafmt/datafmt.go index 6d816fc2d..a8efdc58f 100644 --- a/src/pkg/exp/datafmt/datafmt.go +++ b/src/pkg/exp/datafmt/datafmt.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -/* The datafmt package implements syntax-directed, type-driven formatting +/* Package datafmt implements syntax-directed, type-driven formatting of arbitrary data structures. Formatting a data structure consists of two phases: first, a parser reads a format specification and builds a "compiled" format. Then, the format can be applied repeatedly to @@ -671,7 +671,7 @@ func (f Format) Eval(env Environment, args ...interface{}) ([]byte, os.Error) { go func() { for _, v := range args { - fld := reflect.NewValue(v) + fld := reflect.ValueOf(v) if !fld.IsValid() { errors <- os.NewError("nil argument") return diff --git a/src/pkg/exp/draw/x11/conn.go b/src/pkg/exp/draw/x11/conn.go index 53294af15..81c67267d 100644 --- a/src/pkg/exp/draw/x11/conn.go +++ b/src/pkg/exp/draw/x11/conn.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements an X11 backend for the exp/draw package. +// Package x11 implements an X11 backend for the exp/draw package. // // The X protocol specification is at ftp://ftp.x.org/pub/X11R7.0/doc/PDF/proto.pdf. // A summary of the wire format can be found in XCB's xproto.xml. diff --git a/src/pkg/exp/eval/bridge.go b/src/pkg/exp/eval/bridge.go index d1efa2eb6..f31d9ab9b 100644 --- a/src/pkg/exp/eval/bridge.go +++ b/src/pkg/exp/eval/bridge.go @@ -128,7 +128,7 @@ func TypeFromNative(t reflect.Type) Type { } // TypeOfNative returns the interpreter Type of a regular Go value. -func TypeOfNative(v interface{}) Type { return TypeFromNative(reflect.Typeof(v)) } +func TypeOfNative(v interface{}) Type { return TypeFromNative(reflect.TypeOf(v)) } /* * Function bridging diff --git a/src/pkg/exp/eval/type.go b/src/pkg/exp/eval/type.go index 0d6dfe923..8a93d8a6c 100644 --- a/src/pkg/exp/eval/type.go +++ b/src/pkg/exp/eval/type.go @@ -86,7 +86,7 @@ func hashTypeArray(key []Type) uintptr { if t == nil { continue } - addr := reflect.NewValue(t).Pointer() + addr := reflect.ValueOf(t).Pointer() hash ^= addr } return hash diff --git a/src/pkg/exp/eval/world.go b/src/pkg/exp/eval/world.go index 02d18bd79..a5f6ac7e5 100644 --- a/src/pkg/exp/eval/world.go +++ b/src/pkg/exp/eval/world.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package is the beginning of an interpreter for Go. +// Package eval is the beginning of an interpreter for Go. // It can run simple Go programs but does not implement // interface values or packages. package eval diff --git a/src/pkg/exp/ogle/cmd.go b/src/pkg/exp/ogle/cmd.go index 813d3a875..a8db523ea 100644 --- a/src/pkg/exp/ogle/cmd.go +++ b/src/pkg/exp/ogle/cmd.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Ogle is the beginning of a debugger for Go. +// Package ogle is the beginning of a debugger for Go. package ogle import ( diff --git a/src/pkg/exp/ogle/process.go b/src/pkg/exp/ogle/process.go index e4f44b6fc..7c803b3a2 100644 --- a/src/pkg/exp/ogle/process.go +++ b/src/pkg/exp/ogle/process.go @@ -226,7 +226,7 @@ func (p *Process) bootstrap() { p.runtime.G = newManualType(eval.TypeOfNative(rt1G{}), p.Arch) // Get addresses of type.*runtime.XType for discrimination. - rtv := reflect.Indirect(reflect.NewValue(&p.runtime)) + rtv := reflect.Indirect(reflect.ValueOf(&p.runtime)) rtvt := rtv.Type() for i := 0; i < rtv.NumField(); i++ { n := rtvt.Field(i).Name diff --git a/src/pkg/exp/ogle/rruntime.go b/src/pkg/exp/ogle/rruntime.go index e234f3186..950418b53 100644 --- a/src/pkg/exp/ogle/rruntime.go +++ b/src/pkg/exp/ogle/rruntime.go @@ -236,9 +236,9 @@ type runtimeValues struct { // indexes gathered from the remoteTypes recorded in a runtimeValues // structure. func fillRuntimeIndexes(runtime *runtimeValues, out *runtimeIndexes) { - outv := reflect.Indirect(reflect.NewValue(out)) + outv := reflect.Indirect(reflect.ValueOf(out)) outt := outv.Type() - runtimev := reflect.Indirect(reflect.NewValue(runtime)) + runtimev := reflect.Indirect(reflect.ValueOf(runtime)) // out contains fields corresponding to each runtime type for i := 0; i < outt.NumField(); i++ { diff --git a/src/pkg/expvar/expvar.go b/src/pkg/expvar/expvar.go index ed6cff78d..7123d4b0f 100644 --- a/src/pkg/expvar/expvar.go +++ b/src/pkg/expvar/expvar.go @@ -2,9 +2,9 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The expvar package provides a standardized interface to public variables, -// such as operation counters in servers. It exposes these variables via -// HTTP at /debug/vars in JSON format. +// Package expvar provides a standardized interface to public variables, such +// as operation counters in servers. It exposes these variables via HTTP at +// /debug/vars in JSON format. // // Operations to set or modify these public variables are atomic. // @@ -180,23 +180,14 @@ func (v *String) String() string { return strconv.Quote(v.s) } func (v *String) Set(value string) { v.s = value } -// IntFunc wraps a func() int64 to create a value that satisfies the Var interface. -// The function will be called each time the Var is evaluated. -type IntFunc func() int64 +// Func implements Var by calling the function +// and formatting the returned value using JSON. +type Func func() interface{} -func (v IntFunc) String() string { return strconv.Itoa64(v()) } - -// FloatFunc wraps a func() float64 to create a value that satisfies the Var interface. -// The function will be called each time the Var is evaluated. -type FloatFunc func() float64 - -func (v FloatFunc) String() string { return strconv.Ftoa64(v(), 'g', -1) } - -// StringFunc wraps a func() string to create value that satisfies the Var interface. -// The function will be called each time the Var is evaluated. -type StringFunc func() string - -func (f StringFunc) String() string { return strconv.Quote(f()) } +func (f Func) String() string { + v, _ := json.Marshal(f()) + return string(v) +} // All published variables. @@ -282,18 +273,16 @@ func expvarHandler(w http.ResponseWriter, r *http.Request) { fmt.Fprintf(w, "\n}\n") } -func memstats() string { - b, _ := json.MarshalIndent(&runtime.MemStats, "", "\t") - return string(b) +func cmdline() interface{} { + return os.Args } -func cmdline() string { - b, _ := json.Marshal(os.Args) - return string(b) +func memstats() interface{} { + return runtime.MemStats } func init() { http.Handle("/debug/vars", http.HandlerFunc(expvarHandler)) - Publish("cmdline", StringFunc(cmdline)) - Publish("memstats", StringFunc(memstats)) + Publish("cmdline", Func(cmdline)) + Publish("memstats", Func(memstats)) } diff --git a/src/pkg/expvar/expvar_test.go b/src/pkg/expvar/expvar_test.go index a8b1a96a9..94926d9f8 100644 --- a/src/pkg/expvar/expvar_test.go +++ b/src/pkg/expvar/expvar_test.go @@ -114,41 +114,15 @@ func TestMapCounter(t *testing.T) { } } -func TestIntFunc(t *testing.T) { - x := int64(4) - ix := IntFunc(func() int64 { return x }) - if s := ix.String(); s != "4" { - t.Errorf("ix.String() = %v, want 4", s) +func TestFunc(t *testing.T) { + var x interface{} = []string{"a", "b"} + f := Func(func() interface{} { return x }) + if s, exp := f.String(), `["a","b"]`; s != exp { + t.Errorf(`f.String() = %q, want %q`, s, exp) } - x++ - if s := ix.String(); s != "5" { - t.Errorf("ix.String() = %v, want 5", s) - } -} - -func TestFloatFunc(t *testing.T) { - x := 8.5 - ix := FloatFunc(func() float64 { return x }) - if s := ix.String(); s != "8.5" { - t.Errorf("ix.String() = %v, want 3.14", s) - } - - x -= 1.25 - if s := ix.String(); s != "7.25" { - t.Errorf("ix.String() = %v, want 4.34", s) - } -} - -func TestStringFunc(t *testing.T) { - x := "hello" - sx := StringFunc(func() string { return x }) - if s, exp := sx.String(), `"hello"`; s != exp { - t.Errorf(`sx.String() = %q, want %q`, s, exp) - } - - x = "goodbye" - if s, exp := sx.String(), `"goodbye"`; s != exp { - t.Errorf(`sx.String() = %q, want %q`, s, exp) + x = 17 + if s, exp := f.String(), `17`; s != exp { + t.Errorf(`f.String() = %q, want %q`, s, exp) } } diff --git a/src/pkg/flag/flag.go b/src/pkg/flag/flag.go index 19a310455..9ed20e06b 100644 --- a/src/pkg/flag/flag.go +++ b/src/pkg/flag/flag.go @@ -3,7 +3,7 @@ // license that can be found in the LICENSE file. /* - The flag package implements command-line flag parsing. + Package flag implements command-line flag parsing. Usage: diff --git a/src/pkg/fmt/doc.go b/src/pkg/fmt/doc.go index 77ee62bb1..e4d4f1844 100644 --- a/src/pkg/fmt/doc.go +++ b/src/pkg/fmt/doc.go @@ -27,7 +27,7 @@ %o base 8 %x base 16, with lower-case letters for a-f %X base 16, with upper-case letters for A-F - %U Unicode format: U+1234; same as "U+%x" with 4 digits default + %U Unicode format: U+1234; same as "U+%0.4X" Floating-point and complex constituents: %b decimalless scientific notation with exponent a power of two, in the manner of strconv.Ftoa32, e.g. -123456p-78 diff --git a/src/pkg/fmt/print.go b/src/pkg/fmt/print.go index 7fca6afe4..10e0fe7c8 100644 --- a/src/pkg/fmt/print.go +++ b/src/pkg/fmt/print.go @@ -260,7 +260,7 @@ func getField(v reflect.Value, i int) reflect.Value { val := v.Field(i) if i := val; i.Kind() == reflect.Interface { if inter := i.Interface(); inter != nil { - return reflect.NewValue(inter) + return reflect.ValueOf(inter) } } return val @@ -284,7 +284,7 @@ func (p *pp) unknownType(v interface{}) { return } p.buf.WriteByte('?') - p.buf.WriteString(reflect.Typeof(v).String()) + p.buf.WriteString(reflect.TypeOf(v).String()) p.buf.WriteByte('?') } @@ -296,7 +296,7 @@ func (p *pp) badVerb(verb int, val interface{}) { if val == nil { p.buf.Write(nilAngleBytes) } else { - p.buf.WriteString(reflect.Typeof(val).String()) + p.buf.WriteString(reflect.TypeOf(val).String()) p.add('=') p.printField(val, 'v', false, false, 0) } @@ -527,7 +527,7 @@ func (p *pp) fmtPointer(field interface{}, value reflect.Value, verb int, goSynt } if goSyntax { p.add('(') - p.buf.WriteString(reflect.Typeof(field).String()) + p.buf.WriteString(reflect.TypeOf(field).String()) p.add(')') p.add('(') if u == 0 { @@ -542,10 +542,10 @@ func (p *pp) fmtPointer(field interface{}, value reflect.Value, verb int, goSynt } var ( - intBits = reflect.Typeof(0).Bits() - floatBits = reflect.Typeof(0.0).Bits() - complexBits = reflect.Typeof(1i).Bits() - uintptrBits = reflect.Typeof(uintptr(0)).Bits() + intBits = reflect.TypeOf(0).Bits() + floatBits = reflect.TypeOf(0.0).Bits() + complexBits = reflect.TypeOf(1i).Bits() + uintptrBits = reflect.TypeOf(uintptr(0)).Bits() ) func (p *pp) printField(field interface{}, verb int, plus, goSyntax bool, depth int) (wasString bool) { @@ -562,10 +562,10 @@ func (p *pp) printField(field interface{}, verb int, plus, goSyntax bool, depth // %T (the value's type) and %p (its address) are special; we always do them first. switch verb { case 'T': - p.printField(reflect.Typeof(field).String(), 's', false, false, 0) + p.printField(reflect.TypeOf(field).String(), 's', false, false, 0) return false case 'p': - p.fmtPointer(field, reflect.NewValue(field), verb, goSyntax) + p.fmtPointer(field, reflect.ValueOf(field), verb, goSyntax) return false } // Is it a Formatter? @@ -653,7 +653,7 @@ func (p *pp) printField(field interface{}, verb int, plus, goSyntax bool, depth } // Need to use reflection - value := reflect.NewValue(field) + value := reflect.ValueOf(field) BigSwitch: switch f := value; f.Kind() { @@ -704,7 +704,7 @@ BigSwitch: } case reflect.Struct: if goSyntax { - p.buf.WriteString(reflect.Typeof(field).String()) + p.buf.WriteString(reflect.TypeOf(field).String()) } p.add('{') v := f @@ -730,7 +730,7 @@ BigSwitch: value := f.Elem() if !value.IsValid() { if goSyntax { - p.buf.WriteString(reflect.Typeof(field).String()) + p.buf.WriteString(reflect.TypeOf(field).String()) p.buf.Write(nilParenBytes) } else { p.buf.Write(nilAngleBytes) @@ -756,7 +756,7 @@ BigSwitch: return verb == 's' } if goSyntax { - p.buf.WriteString(reflect.Typeof(field).String()) + p.buf.WriteString(reflect.TypeOf(field).String()) p.buf.WriteByte('{') } else { p.buf.WriteByte('[') @@ -794,7 +794,7 @@ BigSwitch: } if goSyntax { p.buf.WriteByte('(') - p.buf.WriteString(reflect.Typeof(field).String()) + p.buf.WriteString(reflect.TypeOf(field).String()) p.buf.WriteByte(')') p.buf.WriteByte('(') if v == 0 { @@ -915,7 +915,7 @@ func (p *pp) doPrintf(format string, a []interface{}) { for ; fieldnum < len(a); fieldnum++ { field := a[fieldnum] if field != nil { - p.buf.WriteString(reflect.Typeof(field).String()) + p.buf.WriteString(reflect.TypeOf(field).String()) p.buf.WriteByte('=') } p.printField(field, 'v', false, false, 0) @@ -934,7 +934,7 @@ func (p *pp) doPrint(a []interface{}, addspace, addnewline bool) { // always add spaces if we're doing println field := a[fieldnum] if fieldnum > 0 { - isString := field != nil && reflect.Typeof(field).Kind() == reflect.String + isString := field != nil && reflect.TypeOf(field).Kind() == reflect.String if addspace || !isString && !prevString { p.buf.WriteByte(' ') } diff --git a/src/pkg/fmt/scan.go b/src/pkg/fmt/scan.go index b1b3975e2..42bc52c92 100644 --- a/src/pkg/fmt/scan.go +++ b/src/pkg/fmt/scan.go @@ -423,7 +423,7 @@ func (s *ss) token(skipSpace bool, f func(int) bool) []byte { // typeError indicates that the type of the operand did not match the format func (s *ss) typeError(field interface{}, expected string) { - s.errorString("expected field of type pointer to " + expected + "; found " + reflect.Typeof(field).String()) + s.errorString("expected field of type pointer to " + expected + "; found " + reflect.TypeOf(field).String()) } var complexError = os.ErrorString("syntax error scanning complex number") @@ -908,7 +908,7 @@ func (s *ss) scanOne(verb int, field interface{}) { // If we scanned to bytes, the slice would point at the buffer. *v = []byte(s.convertString(verb)) default: - val := reflect.NewValue(v) + val := reflect.ValueOf(v) ptr := val if ptr.Kind() != reflect.Ptr { s.errorString("Scan: type not a pointer: " + val.Type().String()) diff --git a/src/pkg/fmt/scan_test.go b/src/pkg/fmt/scan_test.go index b8b3ac975..da13eb2d1 100644 --- a/src/pkg/fmt/scan_test.go +++ b/src/pkg/fmt/scan_test.go @@ -370,7 +370,7 @@ func testScan(name string, t *testing.T, scan func(r io.Reader, a ...interface{} continue } // The incoming value may be a pointer - v := reflect.NewValue(test.in) + v := reflect.ValueOf(test.in) if p := v; p.Kind() == reflect.Ptr { v = p.Elem() } @@ -409,7 +409,7 @@ func TestScanf(t *testing.T) { continue } // The incoming value may be a pointer - v := reflect.NewValue(test.in) + v := reflect.ValueOf(test.in) if p := v; p.Kind() == reflect.Ptr { v = p.Elem() } @@ -486,7 +486,7 @@ func TestInf(t *testing.T) { } func testScanfMulti(name string, t *testing.T) { - sliceType := reflect.Typeof(make([]interface{}, 1)) + sliceType := reflect.TypeOf(make([]interface{}, 1)) for _, test := range multiTests { var r io.Reader if name == "StringReader" { @@ -513,7 +513,7 @@ func testScanfMulti(name string, t *testing.T) { // Convert the slice of pointers into a slice of values resultVal := reflect.MakeSlice(sliceType, n, n) for i := 0; i < n; i++ { - v := reflect.NewValue(test.in[i]).Elem() + v := reflect.ValueOf(test.in[i]).Elem() resultVal.Index(i).Set(v) } result := resultVal.Interface() @@ -810,7 +810,9 @@ func TestScanInts(t *testing.T) { }) } -const intCount = 1000 +// 800 is small enough to not overflow the stack when using gccgo on a +// platform that does not support split stack. +const intCount = 800 func testScanInts(t *testing.T, scan func(*RecursiveInt, *bytes.Buffer) os.Error) { r := new(RecursiveInt) diff --git a/src/pkg/go/ast/ast.go b/src/pkg/go/ast/ast.go index ed3e2cdd9..2fc1a6032 100644 --- a/src/pkg/go/ast/ast.go +++ b/src/pkg/go/ast/ast.go @@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The AST package declares the types used to represent -// syntax trees for Go packages. +// Package ast declares the types used to represent syntax trees for Go +// packages. // package ast diff --git a/src/pkg/go/ast/print.go b/src/pkg/go/ast/print.go index e6d4e838d..81e1da1d0 100644 --- a/src/pkg/go/ast/print.go +++ b/src/pkg/go/ast/print.go @@ -62,7 +62,7 @@ func Fprint(w io.Writer, fset *token.FileSet, x interface{}, f FieldFilter) (n i p.printf("nil\n") return } - p.print(reflect.NewValue(x)) + p.print(reflect.ValueOf(x)) p.printf("\n") return diff --git a/src/pkg/go/doc/doc.go b/src/pkg/go/doc/doc.go index e7a8d3f63..29d205d39 100644 --- a/src/pkg/go/doc/doc.go +++ b/src/pkg/go/doc/doc.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The doc package extracts source code documentation from a Go AST. +// Package doc extracts source code documentation from a Go AST. package doc import ( diff --git a/src/pkg/go/parser/parser.go b/src/pkg/go/parser/parser.go index 84a0da6ae..5c57e41d1 100644 --- a/src/pkg/go/parser/parser.go +++ b/src/pkg/go/parser/parser.go @@ -2,10 +2,10 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// A parser for Go source files. Input may be provided in a variety of -// forms (see the various Parse* functions); the output is an abstract -// syntax tree (AST) representing the Go source. The parser is invoked -// through one of the Parse* functions. +// Package parser implements a parser for Go source files. Input may be +// provided in a variety of forms (see the various Parse* functions); the +// output is an abstract syntax tree (AST) representing the Go source. The +// parser is invoked through one of the Parse* functions. // package parser diff --git a/src/pkg/go/printer/printer.go b/src/pkg/go/printer/printer.go index 697a83fa8..01ebf783c 100644 --- a/src/pkg/go/printer/printer.go +++ b/src/pkg/go/printer/printer.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The printer package implements printing of AST nodes. +// Package printer implements printing of AST nodes. package printer import ( diff --git a/src/pkg/go/scanner/scanner.go b/src/pkg/go/scanner/scanner.go index 2f949ad25..07b7454c8 100644 --- a/src/pkg/go/scanner/scanner.go +++ b/src/pkg/go/scanner/scanner.go @@ -2,9 +2,9 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// A scanner for Go source text. Takes a []byte as source which can -// then be tokenized through repeated calls to the Scan function. -// Typical use: +// Package scanner implements a scanner for Go source text. Takes a []byte as +// source which can then be tokenized through repeated calls to the Scan +// function. Typical use: // // var s Scanner // fset := token.NewFileSet() // position information is relative to fset diff --git a/src/pkg/go/token/token.go b/src/pkg/go/token/token.go index a5f21df16..c2ec80ae1 100644 --- a/src/pkg/go/token/token.go +++ b/src/pkg/go/token/token.go @@ -2,9 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package defines constants representing the lexical -// tokens of the Go programming language and basic operations -// on tokens (printing, predicates). +// Package token defines constants representing the lexical tokens of the Go +// programming language and basic operations on tokens (printing, predicates). // package token diff --git a/src/pkg/go/types/gcimporter.go b/src/pkg/go/types/gcimporter.go index 9e0ae6285..30adc04e7 100644 --- a/src/pkg/go/types/gcimporter.go +++ b/src/pkg/go/types/gcimporter.go @@ -461,7 +461,13 @@ func (p *gcParser) parseFuncType() Type { // MethodSpec = identifier Signature . // func (p *gcParser) parseMethodSpec(scope *ast.Scope) { - p.expect(scanner.Ident) + if p.tok == scanner.Ident { + p.expect(scanner.Ident) + } else { + p.parsePkgId() + p.expect('.') + p.parseDotIdent() + } isVariadic := false p.parseSignature(scope, &isVariadic) } diff --git a/src/pkg/go/types/types.go b/src/pkg/go/types/types.go index 72384e121..2ee645d98 100644 --- a/src/pkg/go/types/types.go +++ b/src/pkg/go/types/types.go @@ -3,7 +3,7 @@ // license that can be found in the LICENSE file. // PACKAGE UNDER CONSTRUCTION. ANY AND ALL PARTS MAY CHANGE. -// The types package declares the types used to represent Go types. +// Package types declares the types used to represent Go types. // package types diff --git a/src/pkg/gob/codec_test.go b/src/pkg/gob/codec_test.go index 28042ccaa..8961336cd 100644 --- a/src/pkg/gob/codec_test.go +++ b/src/pkg/gob/codec_test.go @@ -999,13 +999,12 @@ type Bad0 struct { C float64 } - func TestInvalidField(t *testing.T) { var bad0 Bad0 bad0.CH = make(chan int) b := new(bytes.Buffer) dummyEncoder := new(Encoder) // sufficient for this purpose. - dummyEncoder.encode(b, reflect.NewValue(&bad0), userType(reflect.Typeof(&bad0))) + dummyEncoder.encode(b, reflect.ValueOf(&bad0), userType(reflect.TypeOf(&bad0))) if err := dummyEncoder.err; err == nil { t.Error("expected error; got none") } else if strings.Index(err.String(), "type") < 0 { diff --git a/src/pkg/gob/debug.go b/src/pkg/gob/debug.go index 69c83bda7..79aee7788 100644 --- a/src/pkg/gob/debug.go +++ b/src/pkg/gob/debug.go @@ -335,7 +335,7 @@ func (deb *debugger) string() string { func (deb *debugger) delta(expect int) int { delta := int(deb.uint64()) if delta < 0 || (expect >= 0 && delta != expect) { - errorf("gob decode: corrupted type: delta %d expected %d", delta, expect) + errorf("decode: corrupted type: delta %d expected %d", delta, expect) } return delta } diff --git a/src/pkg/gob/decode.go b/src/pkg/gob/decode.go index 51fac798d..0e86df6b5 100644 --- a/src/pkg/gob/decode.go +++ b/src/pkg/gob/decode.go @@ -406,7 +406,7 @@ func decUint8Array(i *decInstr, state *decoderState, p unsafe.Pointer) { func decString(i *decInstr, state *decoderState, p unsafe.Pointer) { if i.indir > 0 { if *(*unsafe.Pointer)(p) == nil { - *(*unsafe.Pointer)(p) = unsafe.Pointer(new([]byte)) + *(*unsafe.Pointer)(p) = unsafe.Pointer(new(string)) } p = *(*unsafe.Pointer)(p) } @@ -468,7 +468,7 @@ func (dec *Decoder) decodeSingle(engine *decEngine, ut *userTypeInfo, p uintptr) basep := p delta := int(state.decodeUint()) if delta != 0 { - errorf("gob decode: corrupted data: non-zero delta for singleton") + errorf("decode: corrupted data: non-zero delta for singleton") } instr := &engine.instr[singletonField] ptr := unsafe.Pointer(basep) // offset will be zero @@ -493,7 +493,7 @@ func (dec *Decoder) decodeStruct(engine *decEngine, ut *userTypeInfo, p uintptr, for state.b.Len() > 0 { delta := int(state.decodeUint()) if delta < 0 { - errorf("gob decode: corrupted data: negative delta") + errorf("decode: corrupted data: negative delta") } if delta == 0 { // struct terminator is zero delta fieldnum break @@ -521,7 +521,7 @@ func (dec *Decoder) ignoreStruct(engine *decEngine) { for state.b.Len() > 0 { delta := int(state.decodeUint()) if delta < 0 { - errorf("gob ignore decode: corrupted data: negative delta") + errorf("ignore decode: corrupted data: negative delta") } if delta == 0 { // struct terminator is zero delta fieldnum break @@ -544,7 +544,7 @@ func (dec *Decoder) ignoreSingle(engine *decEngine) { state.fieldnum = singletonField delta := int(state.decodeUint()) if delta != 0 { - errorf("gob decode: corrupted data: non-zero delta for singleton") + errorf("decode: corrupted data: non-zero delta for singleton") } instr := &engine.instr[singletonField] instr.op(instr, state, unsafe.Pointer(nil)) @@ -572,7 +572,7 @@ func (dec *Decoder) decodeArray(atyp reflect.Type, state *decoderState, p uintpt p = allocate(atyp, p, 1) // All but the last level has been allocated by dec.Indirect } if n := state.decodeUint(); n != uint64(length) { - errorf("gob: length mismatch in decodeArray") + errorf("length mismatch in decodeArray") } dec.decodeArrayHelper(state, p, elemOp, elemWid, length, elemIndir, ovfl) } @@ -581,7 +581,7 @@ func (dec *Decoder) decodeArray(atyp reflect.Type, state *decoderState, p uintpt // unlike the other items we can't use a pointer directly. func decodeIntoValue(state *decoderState, op decOp, indir int, v reflect.Value, ovfl os.ErrorString) reflect.Value { instr := &decInstr{op, 0, indir, 0, ovfl} - up := unsafe.Pointer(v.UnsafeAddr()) + up := unsafe.Pointer(unsafeAddr(v)) if indir > 1 { up = decIndirect(up, indir) } @@ -605,11 +605,11 @@ func (dec *Decoder) decodeMap(mtyp reflect.Type, state *decoderState, p uintptr, // Maps cannot be accessed by moving addresses around the way // that slices etc. can. We must recover a full reflection value for // the iteration. - v := reflect.NewValue(unsafe.Unreflect(mtyp, unsafe.Pointer(p))) + v := reflect.ValueOf(unsafe.Unreflect(mtyp, unsafe.Pointer(p))) n := int(state.decodeUint()) for i := 0; i < n; i++ { - key := decodeIntoValue(state, keyOp, keyIndir, reflect.Zero(mtyp.Key()), ovfl) - elem := decodeIntoValue(state, elemOp, elemIndir, reflect.Zero(mtyp.Elem()), ovfl) + key := decodeIntoValue(state, keyOp, keyIndir, allocValue(mtyp.Key()), ovfl) + elem := decodeIntoValue(state, elemOp, elemIndir, allocValue(mtyp.Elem()), ovfl) v.SetMapIndex(key, elem) } } @@ -625,7 +625,7 @@ func (dec *Decoder) ignoreArrayHelper(state *decoderState, elemOp decOp, length // ignoreArray discards the data for an array value with no destination. func (dec *Decoder) ignoreArray(state *decoderState, elemOp decOp, length int) { if n := state.decodeUint(); n != uint64(length) { - errorf("gob: length mismatch in ignoreArray") + errorf("length mismatch in ignoreArray") } dec.ignoreArrayHelper(state, elemOp, length) } @@ -667,18 +667,12 @@ func (dec *Decoder) ignoreSlice(state *decoderState, elemOp decOp) { dec.ignoreArrayHelper(state, elemOp, int(state.decodeUint())) } -// setInterfaceValue sets an interface value to a concrete value through -// reflection. If the concrete value does not implement the interface, the -// setting will panic. This routine turns the panic into an error return. -// This dance avoids manually checking that the value satisfies the -// interface. -// TODO(rsc): avoid panic+recover after fixing issue 327. +// setInterfaceValue sets an interface value to a concrete value, +// but first it checks that the assignment will succeed. func setInterfaceValue(ivalue reflect.Value, value reflect.Value) { - defer func() { - if e := recover(); e != nil { - error(e.(os.Error)) - } - }() + if !value.Type().AssignableTo(ivalue.Type()) { + errorf("cannot assign value of type %s to %s", value.Type(), ivalue.Type()) + } ivalue.Set(value) } @@ -686,8 +680,8 @@ func setInterfaceValue(ivalue reflect.Value, value reflect.Value) { // Interfaces are encoded as the name of a concrete type followed by a value. // If the name is empty, the value is nil and no value is sent. func (dec *Decoder) decodeInterface(ityp reflect.Type, state *decoderState, p uintptr, indir int) { - // Create an interface reflect.Value. We need one even for the nil case. - ivalue := reflect.Zero(ityp) + // Create a writable interface reflect.Value. We need one even for the nil case. + ivalue := allocValue(ityp) // Read the name of the concrete type. b := make([]byte, state.decodeUint()) state.b.Read(b) @@ -701,7 +695,7 @@ func (dec *Decoder) decodeInterface(ityp reflect.Type, state *decoderState, p ui // The concrete type must be registered. typ, ok := nameToConcreteType[name] if !ok { - errorf("gob: name not registered for interface: %q", name) + errorf("name not registered for interface: %q", name) } // Read the type id of the concrete value. concreteId := dec.decodeTypeSequence(true) @@ -712,7 +706,7 @@ func (dec *Decoder) decodeInterface(ityp reflect.Type, state *decoderState, p ui // in case we want to ignore the value by skipping it completely). state.decodeUint() // Read the concrete value. - value := reflect.Zero(typ) + value := allocValue(typ) dec.decodeValue(concreteId, value) if dec.err != nil { error(dec.err) @@ -880,7 +874,7 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string, inProg } } if op == nil { - errorf("gob: decode can't handle type %s", rt.String()) + errorf("decode can't handle type %s", rt.String()) } return &op, indir } @@ -901,7 +895,7 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) decOp { wire := dec.wireType[wireId] switch { case wire == nil: - errorf("gob: bad data: undefined type %s", wireId.string()) + errorf("bad data: undefined type %s", wireId.string()) case wire.ArrayT != nil: elemId := wire.ArrayT.Elem elemOp := dec.decIgnoreOpFor(elemId) @@ -943,7 +937,7 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) decOp { } } if op == nil { - errorf("gob: bad data: ignore can't handle type %s", wireId.string()) + errorf("bad data: ignore can't handle type %s", wireId.string()) } return op } @@ -951,32 +945,33 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) decOp { // gobDecodeOpFor returns the op for a type that is known to implement // GobDecoder. func (dec *Decoder) gobDecodeOpFor(ut *userTypeInfo) (*decOp, int) { - rt := ut.user + rcvrType := ut.user if ut.decIndir == -1 { - rt = reflect.PtrTo(rt) + rcvrType = reflect.PtrTo(rcvrType) } else if ut.decIndir > 0 { for i := int8(0); i < ut.decIndir; i++ { - rt = rt.Elem() + rcvrType = rcvrType.Elem() } } var op decOp op = func(i *decInstr, state *decoderState, p unsafe.Pointer) { - // Allocate the underlying data, but hold on to the address we have, - // since we need it to get to the receiver's address. - allocate(ut.base, uintptr(p), ut.indir) + // Caller has gotten us to within one indirection of our value. + if i.indir > 0 { + if *(*unsafe.Pointer)(p) == nil { + *(*unsafe.Pointer)(p) = unsafe.New(ut.base) + } + } + // Now p is a pointer to the base type. Do we need to climb out to + // get to the receiver type? var v reflect.Value if ut.decIndir == -1 { - // Need to climb up one level to turn value into pointer. - v = reflect.NewValue(unsafe.Unreflect(rt, unsafe.Pointer(&p))) + v = reflect.ValueOf(unsafe.Unreflect(rcvrType, unsafe.Pointer(&p))) } else { - if ut.decIndir > 0 { - p = decIndirect(p, int(ut.decIndir)) - } - v = reflect.NewValue(unsafe.Unreflect(rt, p)) + v = reflect.ValueOf(unsafe.Unreflect(rcvrType, p)) } - state.dec.decodeGobDecoder(state, v, methodIndex(rt, gobDecodeMethodName)) + state.dec.decodeGobDecoder(state, v, methodIndex(rcvrType, gobDecodeMethodName)) } - return &op, int(ut.decIndir) + return &op, int(ut.indir) } @@ -1111,7 +1106,7 @@ func (dec *Decoder) compileDec(remoteId typeId, ut *userTypeInfo) (engine *decEn wireStruct = wire.StructT } if wireStruct == nil { - errorf("gob: type mismatch in decoder: want struct type %s; got non-struct", rt.String()) + errorf("type mismatch in decoder: want struct type %s; got non-struct", rt.String()) } engine = new(decEngine) engine.instr = make([]decInstr, len(wireStruct.Field)) @@ -1120,7 +1115,7 @@ func (dec *Decoder) compileDec(remoteId typeId, ut *userTypeInfo) (engine *decEn for fieldnum := 0; fieldnum < len(wireStruct.Field); fieldnum++ { wireField := wireStruct.Field[fieldnum] if wireField.Name == "" { - errorf("gob: empty name for remote field of type %s", wireStruct.Name) + errorf("empty name for remote field of type %s", wireStruct.Name) } ovfl := overflow(wireField.Name) // Find the field of the local type with the same name. @@ -1132,7 +1127,7 @@ func (dec *Decoder) compileDec(remoteId typeId, ut *userTypeInfo) (engine *decEn continue } if !dec.compatibleType(localField.Type, wireField.Id, make(map[reflect.Type]typeId)) { - errorf("gob: wrong type (%s) for received field %s.%s", localField.Type, wireStruct.Name, wireField.Name) + errorf("wrong type (%s) for received field %s.%s", localField.Type, wireStruct.Name, wireField.Name) } op, indir := dec.decOpFor(wireField.Id, localField.Type, localField.Name, seen) engine.instr[fieldnum] = decInstr{*op, fieldnum, indir, uintptr(localField.Offset), ovfl} @@ -1164,7 +1159,7 @@ func (dec *Decoder) getDecEnginePtr(remoteId typeId, ut *userTypeInfo) (enginePt // emptyStruct is the type we compile into when ignoring a struct value. type emptyStruct struct{} -var emptyStructType = reflect.Typeof(emptyStruct{}) +var emptyStructType = reflect.TypeOf(emptyStruct{}) // getDecEnginePtr returns the engine for the specified type when the value is to be discarded. func (dec *Decoder) getIgnoreEnginePtr(wireId typeId) (enginePtr **decEngine, err os.Error) { @@ -1197,10 +1192,6 @@ func (dec *Decoder) decodeValue(wireId typeId, val reflect.Value) { // Dereference down to the underlying struct type. ut := userType(val.Type()) base := ut.base - indir := ut.indir - if ut.isGobDecoder { - indir = int(ut.decIndir) - } var enginePtr **decEngine enginePtr, dec.err = dec.getDecEnginePtr(wireId, ut) if dec.err != nil { @@ -1210,11 +1201,11 @@ func (dec *Decoder) decodeValue(wireId typeId, val reflect.Value) { if st := base; st.Kind() == reflect.Struct && !ut.isGobDecoder { if engine.numInstr == 0 && st.NumField() > 0 && len(dec.wireType[wireId].StructT.Field) > 0 { name := base.Name() - errorf("gob: type mismatch: no fields matched compiling decoder for %s", name) + errorf("type mismatch: no fields matched compiling decoder for %s", name) } - dec.decodeStruct(engine, ut, uintptr(val.UnsafeAddr()), indir) + dec.decodeStruct(engine, ut, uintptr(unsafeAddr(val)), ut.indir) } else { - dec.decodeSingle(engine, ut, uintptr(val.UnsafeAddr())) + dec.decodeSingle(engine, ut, uintptr(unsafeAddr(val))) } } @@ -1235,7 +1226,7 @@ func (dec *Decoder) decodeIgnoredValue(wireId typeId) { func init() { var iop, uop decOp - switch reflect.Typeof(int(0)).Bits() { + switch reflect.TypeOf(int(0)).Bits() { case 32: iop = decInt32 uop = decUint32 @@ -1249,7 +1240,7 @@ func init() { decOpTable[reflect.Uint] = uop // Finally uintptr - switch reflect.Typeof(uintptr(0)).Bits() { + switch reflect.TypeOf(uintptr(0)).Bits() { case 32: uop = decUint32 case 64: @@ -1259,3 +1250,26 @@ func init() { } decOpTable[reflect.Uintptr] = uop } + +// Gob assumes it can call UnsafeAddr on any Value +// in order to get a pointer it can copy data from. +// Values that have just been created and do not point +// into existing structs or slices cannot be addressed, +// so simulate it by returning a pointer to a copy. +// Each call allocates once. +func unsafeAddr(v reflect.Value) uintptr { + if v.CanAddr() { + return v.UnsafeAddr() + } + x := reflect.New(v.Type()).Elem() + x.Set(v) + return x.UnsafeAddr() +} + +// Gob depends on being able to take the address +// of zeroed Values it creates, so use this wrapper instead +// of the standard reflect.Zero. +// Each call allocates once. +func allocValue(t reflect.Type) reflect.Value { + return reflect.New(t).Elem() +} diff --git a/src/pkg/gob/decoder.go b/src/pkg/gob/decoder.go index a631c27a2..ea2f62ec5 100644 --- a/src/pkg/gob/decoder.go +++ b/src/pkg/gob/decoder.go @@ -50,7 +50,7 @@ func (dec *Decoder) recvType(id typeId) { // Type: wire := new(wireType) - dec.decodeValue(tWireType, reflect.NewValue(wire)) + dec.decodeValue(tWireType, reflect.ValueOf(wire)) if dec.err != nil { return } @@ -161,7 +161,7 @@ func (dec *Decoder) Decode(e interface{}) os.Error { if e == nil { return dec.DecodeValue(reflect.Value{}) } - value := reflect.NewValue(e) + value := reflect.ValueOf(e) // If e represents a value as opposed to a pointer, the answer won't // get back to the caller. Make sure it's a pointer. if value.Type().Kind() != reflect.Ptr { @@ -171,12 +171,18 @@ func (dec *Decoder) Decode(e interface{}) os.Error { return dec.DecodeValue(value) } -// DecodeValue reads the next value from the connection and stores -// it in the data represented by the reflection value. -// The value must be the correct type for the next -// data item received, or it may be nil, which means the -// value will be discarded. -func (dec *Decoder) DecodeValue(value reflect.Value) os.Error { +// DecodeValue reads the next value from the connection. +// If v is the zero reflect.Value (v.Kind() == Invalid), DecodeValue discards the value. +// Otherwise, it stores the value into v. In that case, v must represent +// a non-nil pointer to data or be an assignable reflect.Value (v.CanSet()) +func (dec *Decoder) DecodeValue(v reflect.Value) os.Error { + if v.IsValid() { + if v.Kind() == reflect.Ptr && !v.IsNil() { + // That's okay, we'll store through the pointer. + } else if !v.CanSet() { + return os.ErrorString("gob: DecodeValue of unassignable value") + } + } // Make sure we're single-threaded through here. dec.mutex.Lock() defer dec.mutex.Unlock() @@ -185,7 +191,7 @@ func (dec *Decoder) DecodeValue(value reflect.Value) os.Error { dec.err = nil id := dec.decodeTypeSequence(false) if dec.err == nil { - dec.decodeValue(id, value) + dec.decodeValue(id, v) } return dec.err } diff --git a/src/pkg/gob/doc.go b/src/pkg/gob/doc.go index 613974a00..189086f52 100644 --- a/src/pkg/gob/doc.go +++ b/src/pkg/gob/doc.go @@ -3,7 +3,7 @@ // license that can be found in the LICENSE file. /* -The gob package manages streams of gobs - binary values exchanged between an +Package gob manages streams of gobs - binary values exchanged between an Encoder (transmitter) and a Decoder (receiver). A typical use is transporting arguments and results of remote procedure calls (RPCs) such as those provided by package "rpc". diff --git a/src/pkg/gob/encode.go b/src/pkg/gob/encode.go index 36bde08aa..f9e691a2f 100644 --- a/src/pkg/gob/encode.go +++ b/src/pkg/gob/encode.go @@ -384,7 +384,7 @@ func (enc *Encoder) encodeArray(b *bytes.Buffer, p uintptr, op encOp, elemWid ui up := unsafe.Pointer(elemp) if elemIndir > 0 { if up = encIndirect(up, elemIndir); up == nil { - errorf("gob: encodeArray: nil element") + errorf("encodeArray: nil element") } elemp = uintptr(up) } @@ -400,9 +400,9 @@ func encodeReflectValue(state *encoderState, v reflect.Value, op encOp, indir in v = reflect.Indirect(v) } if !v.IsValid() { - errorf("gob: encodeReflectValue: nil element") + errorf("encodeReflectValue: nil element") } - op(nil, state, unsafe.Pointer(v.UnsafeAddr())) + op(nil, state, unsafe.Pointer(unsafeAddr(v))) } // encodeMap encodes a map as unsigned count followed by key:value pairs. @@ -438,7 +438,7 @@ func (enc *Encoder) encodeInterface(b *bytes.Buffer, iv reflect.Value) { ut := userType(iv.Elem().Type()) name, ok := concreteTypeToName[ut.base] if !ok { - errorf("gob: type not registered for interface: %s", ut.base) + errorf("type not registered for interface: %s", ut.base) } // Send the name. state.encodeUint(uint64(len(name))) @@ -555,7 +555,7 @@ func (enc *Encoder) encOpFor(rt reflect.Type, inProgress map[reflect.Type]*encOp // Maps cannot be accessed by moving addresses around the way // that slices etc. can. We must recover a full reflection value for // the iteration. - v := reflect.NewValue(unsafe.Unreflect(t, unsafe.Pointer(p))) + v := reflect.ValueOf(unsafe.Unreflect(t, unsafe.Pointer(p))) mv := reflect.Indirect(v) if !state.sendZero && mv.Len() == 0 { return @@ -576,7 +576,7 @@ func (enc *Encoder) encOpFor(rt reflect.Type, inProgress map[reflect.Type]*encOp op = func(i *encInstr, state *encoderState, p unsafe.Pointer) { // Interfaces transmit the name and contents of the concrete // value they contain. - v := reflect.NewValue(unsafe.Unreflect(t, unsafe.Pointer(p))) + v := reflect.ValueOf(unsafe.Unreflect(t, unsafe.Pointer(p))) iv := reflect.Indirect(v) if !state.sendZero && (!iv.IsValid() || iv.IsNil()) { return @@ -587,7 +587,7 @@ func (enc *Encoder) encOpFor(rt reflect.Type, inProgress map[reflect.Type]*encOp } } if op == nil { - errorf("gob enc: can't happen: encode type %s", rt.String()) + errorf("can't happen: encode type %s", rt.String()) } return &op, indir } @@ -599,7 +599,7 @@ func methodIndex(rt reflect.Type, method string) int { return i } } - errorf("gob: internal error: can't find method %s", method) + errorf("internal error: can't find method %s", method) return 0 } @@ -619,9 +619,9 @@ func (enc *Encoder) gobEncodeOpFor(ut *userTypeInfo) (*encOp, int) { var v reflect.Value if ut.encIndir == -1 { // Need to climb up one level to turn value into pointer. - v = reflect.NewValue(unsafe.Unreflect(rt, unsafe.Pointer(&p))) + v = reflect.ValueOf(unsafe.Unreflect(rt, unsafe.Pointer(&p))) } else { - v = reflect.NewValue(unsafe.Unreflect(rt, p)) + v = reflect.ValueOf(unsafe.Unreflect(rt, p)) } state.update(i) state.enc.encodeGobEncoder(state.b, v, methodIndex(rt, gobEncodeMethodName)) @@ -650,7 +650,7 @@ func (enc *Encoder) compileEnc(ut *userTypeInfo) *encEngine { wireFieldNum++ } if srt.NumField() > 0 && len(engine.instr) == 0 { - errorf("gob: type %s has no exported fields", rt) + errorf("type %s has no exported fields", rt) } engine.instr = append(engine.instr, encInstr{encStructTerminator, 0, 0, 0}) } else { @@ -695,8 +695,8 @@ func (enc *Encoder) encode(b *bytes.Buffer, value reflect.Value, ut *userTypeInf value = reflect.Indirect(value) } if !ut.isGobEncoder && value.Type().Kind() == reflect.Struct { - enc.encodeStruct(b, engine, value.UnsafeAddr()) + enc.encodeStruct(b, engine, unsafeAddr(value)) } else { - enc.encodeSingle(b, engine, value.UnsafeAddr()) + enc.encodeSingle(b, engine, unsafeAddr(value)) } } diff --git a/src/pkg/gob/encoder.go b/src/pkg/gob/encoder.go index 928f3b244..65ee5bf67 100644 --- a/src/pkg/gob/encoder.go +++ b/src/pkg/gob/encoder.go @@ -97,7 +97,7 @@ func (enc *Encoder) sendActualType(w io.Writer, state *encoderState, ut *userTyp // Id: state.encodeInt(-int64(info.id)) // Type: - enc.encode(state.b, reflect.NewValue(info.wire), wireTypeUserInfo) + enc.encode(state.b, reflect.ValueOf(info.wire), wireTypeUserInfo) enc.writeMessage(w, state.b) if enc.err != nil { return @@ -116,6 +116,9 @@ func (enc *Encoder) sendActualType(w io.Writer, state *encoderState, ut *userTyp } case reflect.Array, reflect.Slice: enc.sendType(w, state, st.Elem()) + case reflect.Map: + enc.sendType(w, state, st.Key()) + enc.sendType(w, state, st.Elem()) } return true } @@ -162,7 +165,7 @@ func (enc *Encoder) sendType(w io.Writer, state *encoderState, origt reflect.Typ // Encode transmits the data item represented by the empty interface value, // guaranteeing that all necessary type information has been transmitted first. func (enc *Encoder) Encode(e interface{}) os.Error { - return enc.EncodeValue(reflect.NewValue(e)) + return enc.EncodeValue(reflect.ValueOf(e)) } // sendTypeDescriptor makes sure the remote side knows about this type. diff --git a/src/pkg/gob/encoder_test.go b/src/pkg/gob/encoder_test.go index 3d5dfdb86..792afbd77 100644 --- a/src/pkg/gob/encoder_test.go +++ b/src/pkg/gob/encoder_test.go @@ -170,7 +170,7 @@ func TestTypeToPtrType(t *testing.T) { A int } t0 := Type0{7} - t0p := (*Type0)(nil) + t0p := new(Type0) if err := encAndDec(t0, t0p); err != nil { t.Error(err) } @@ -339,7 +339,7 @@ func TestSingletons(t *testing.T) { continue } // Get rid of the pointer in the rhs - val := reflect.NewValue(test.out).Elem().Interface() + val := reflect.ValueOf(test.out).Elem().Interface() if !reflect.DeepEqual(test.in, val) { t.Errorf("decoding singleton: expected %v got %v", test.in, val) } @@ -514,3 +514,38 @@ func TestNestedInterfaces(t *testing.T) { t.Fatalf("final value %d; expected %d", inner.A, 7) } } + +// The bugs keep coming. We forgot to send map subtypes before the map. + +type Bug1Elem struct { + Name string + Id int +} + +type Bug1StructMap map[string]Bug1Elem + +func bug1EncDec(in Bug1StructMap, out *Bug1StructMap) os.Error { + return nil +} + +func TestMapBug1(t *testing.T) { + in := make(Bug1StructMap) + in["val1"] = Bug1Elem{"elem1", 1} + in["val2"] = Bug1Elem{"elem2", 2} + + b := new(bytes.Buffer) + enc := NewEncoder(b) + err := enc.Encode(in) + if err != nil { + t.Fatal("encode:", err) + } + dec := NewDecoder(b) + out := make(Bug1StructMap) + err = dec.Decode(&out) + if err != nil { + t.Fatal("decode:", err) + } + if !reflect.DeepEqual(in, out) { + t.Errorf("mismatch: %v %v", in, out) + } +} diff --git a/src/pkg/gob/error.go b/src/pkg/gob/error.go index b053761fb..bfd38fc16 100644 --- a/src/pkg/gob/error.go +++ b/src/pkg/gob/error.go @@ -22,8 +22,9 @@ type gobError struct { } // errorf is like error but takes Printf-style arguments to construct an os.Error. +// It always prefixes the message with "gob: ". func errorf(format string, args ...interface{}) { - error(fmt.Errorf(format, args...)) + error(fmt.Errorf("gob: "+format, args...)) } // error wraps the argument error and uses it as the argument to panic. diff --git a/src/pkg/gob/gobencdec_test.go b/src/pkg/gob/gobencdec_test.go index 012b09956..e94534f4c 100644 --- a/src/pkg/gob/gobencdec_test.go +++ b/src/pkg/gob/gobencdec_test.go @@ -24,6 +24,10 @@ type StringStruct struct { s string // not an exported field } +type ArrayStruct struct { + a [8192]byte // not an exported field +} + type Gobber int type ValueGobber string // encodes with a value, decodes with a pointer. @@ -74,6 +78,18 @@ func (g *StringStruct) GobDecode(data []byte) os.Error { return nil } +func (a *ArrayStruct) GobEncode() ([]byte, os.Error) { + return a.a[:], nil +} + +func (a *ArrayStruct) GobDecode(data []byte) os.Error { + if len(data) != len(a.a) { + return os.ErrorString("wrong length in array decode") + } + copy(a.a[:], data) + return nil +} + func (g *Gobber) GobEncode() ([]byte, os.Error) { return []byte(fmt.Sprintf("VALUE=%d", *g)), nil } @@ -138,6 +154,16 @@ type GobTestIndirectEncDec struct { G ***StringStruct // indirections to the receiver. } +type GobTestArrayEncDec struct { + X int // guarantee we have something in common with GobTest* + A ArrayStruct // not a pointer. +} + +type GobTestIndirectArrayEncDec struct { + X int // guarantee we have something in common with GobTest* + A ***ArrayStruct // indirections to a large receiver. +} + func TestGobEncoderField(t *testing.T) { b := new(bytes.Buffer) // First a field that's a structure. @@ -216,6 +242,64 @@ func TestGobEncoderIndirectField(t *testing.T) { } } +// Test with a large field with methods. +func TestGobEncoderArrayField(t *testing.T) { + b := new(bytes.Buffer) + enc := NewEncoder(b) + var a GobTestArrayEncDec + a.X = 17 + for i := range a.A.a { + a.A.a[i] = byte(i) + } + err := enc.Encode(a) + if err != nil { + t.Fatal("encode error:", err) + } + dec := NewDecoder(b) + x := new(GobTestArrayEncDec) + err = dec.Decode(x) + if err != nil { + t.Fatal("decode error:", err) + } + for i, v := range x.A.a { + if v != byte(i) { + t.Errorf("expected %x got %x", byte(i), v) + break + } + } +} + +// Test an indirection to a large field with methods. +func TestGobEncoderIndirectArrayField(t *testing.T) { + b := new(bytes.Buffer) + enc := NewEncoder(b) + var a GobTestIndirectArrayEncDec + a.X = 17 + var array ArrayStruct + ap := &array + app := &ap + a.A = &app + for i := range array.a { + array.a[i] = byte(i) + } + err := enc.Encode(a) + if err != nil { + t.Fatal("encode error:", err) + } + dec := NewDecoder(b) + x := new(GobTestIndirectArrayEncDec) + err = dec.Decode(x) + if err != nil { + t.Fatal("decode error:", err) + } + for i, v := range (***x.A).a { + if v != byte(i) { + t.Errorf("expected %x got %x", byte(i), v) + break + } + } +} + // As long as the fields have the same name and implement the // interface, we can cross-connect them. Not sure it's useful // and may even be bad but it works and it's hard to prevent diff --git a/src/pkg/gob/type.go b/src/pkg/gob/type.go index 8fd174841..c5b8fb5d9 100644 --- a/src/pkg/gob/type.go +++ b/src/pkg/gob/type.go @@ -74,8 +74,8 @@ func validUserType(rt reflect.Type) (ut *userTypeInfo, err os.Error) { } ut.indir++ } - ut.isGobEncoder, ut.encIndir = implementsInterface(ut.user, gobEncoderCheck) - ut.isGobDecoder, ut.decIndir = implementsInterface(ut.user, gobDecoderCheck) + ut.isGobEncoder, ut.encIndir = implementsInterface(ut.user, gobEncoderInterfaceType) + ut.isGobDecoder, ut.decIndir = implementsInterface(ut.user, gobDecoderInterfaceType) userTypeCache[rt] = ut return } @@ -85,32 +85,16 @@ const ( gobDecodeMethodName = "GobDecode" ) -// implements returns whether the type implements the interface, as encoded -// in the check function. -func implements(typ reflect.Type, check func(typ reflect.Type) bool) bool { - if typ.NumMethod() == 0 { // avoid allocations etc. unless there's some chance - return false - } - return check(typ) -} - -// gobEncoderCheck makes the type assertion a boolean function. -func gobEncoderCheck(typ reflect.Type) bool { - _, ok := reflect.Zero(typ).Interface().(GobEncoder) - return ok -} - -// gobDecoderCheck makes the type assertion a boolean function. -func gobDecoderCheck(typ reflect.Type) bool { - _, ok := reflect.Zero(typ).Interface().(GobDecoder) - return ok -} +var ( + gobEncoderInterfaceType = reflect.TypeOf(new(GobEncoder)).Elem() + gobDecoderInterfaceType = reflect.TypeOf(new(GobDecoder)).Elem() +) // implementsInterface reports whether the type implements the -// interface. (The actual check is done through the provided function.) +// gobEncoder/gobDecoder interface. // It also returns the number of indirections required to get to the // implementation. -func implementsInterface(typ reflect.Type, check func(typ reflect.Type) bool) (success bool, indir int8) { +func implementsInterface(typ, gobEncDecType reflect.Type) (success bool, indir int8) { if typ == nil { return } @@ -118,7 +102,7 @@ func implementsInterface(typ reflect.Type, check func(typ reflect.Type) bool) (s // The type might be a pointer and we need to keep // dereferencing to the base type until we find an implementation. for { - if implements(rt, check) { + if rt.Implements(gobEncDecType) { return true, indir } if p := rt; p.Kind() == reflect.Ptr { @@ -134,7 +118,7 @@ func implementsInterface(typ reflect.Type, check func(typ reflect.Type) bool) (s // No luck yet, but if this is a base type (non-pointer), the pointer might satisfy. if typ.Kind() != reflect.Ptr { // Not a pointer, but does the pointer work? - if implements(reflect.PtrTo(typ), check) { + if reflect.PtrTo(typ).Implements(gobEncDecType) { return true, -1 } } @@ -243,18 +227,18 @@ var ( ) // Predefined because it's needed by the Decoder -var tWireType = mustGetTypeInfo(reflect.Typeof(wireType{})).id +var tWireType = mustGetTypeInfo(reflect.TypeOf(wireType{})).id var wireTypeUserInfo *userTypeInfo // userTypeInfo of (*wireType) func init() { // Some magic numbers to make sure there are no surprises. checkId(16, tWireType) - checkId(17, mustGetTypeInfo(reflect.Typeof(arrayType{})).id) - checkId(18, mustGetTypeInfo(reflect.Typeof(CommonType{})).id) - checkId(19, mustGetTypeInfo(reflect.Typeof(sliceType{})).id) - checkId(20, mustGetTypeInfo(reflect.Typeof(structType{})).id) - checkId(21, mustGetTypeInfo(reflect.Typeof(fieldType{})).id) - checkId(23, mustGetTypeInfo(reflect.Typeof(mapType{})).id) + checkId(17, mustGetTypeInfo(reflect.TypeOf(arrayType{})).id) + checkId(18, mustGetTypeInfo(reflect.TypeOf(CommonType{})).id) + checkId(19, mustGetTypeInfo(reflect.TypeOf(sliceType{})).id) + checkId(20, mustGetTypeInfo(reflect.TypeOf(structType{})).id) + checkId(21, mustGetTypeInfo(reflect.TypeOf(fieldType{})).id) + checkId(23, mustGetTypeInfo(reflect.TypeOf(mapType{})).id) builtinIdToType = make(map[typeId]gobType) for k, v := range idToType { @@ -268,7 +252,7 @@ func init() { } nextId = firstUserId registerBasics() - wireTypeUserInfo = userType(reflect.Typeof((*wireType)(nil))) + wireTypeUserInfo = userType(reflect.TypeOf((*wireType)(nil))) } // Array type @@ -569,7 +553,7 @@ func checkId(want, got typeId) { // used for building the basic types; called only from init(). the incoming // interface always refers to a pointer. func bootstrapType(name string, e interface{}, expect typeId) typeId { - rt := reflect.Typeof(e).Elem() + rt := reflect.TypeOf(e).Elem() _, present := types[rt] if present { panic("bootstrap type already present: " + name + ", " + rt.String()) @@ -723,7 +707,7 @@ func RegisterName(name string, value interface{}) { // reserved for nil panic("attempt to register empty name") } - base := userType(reflect.Typeof(value)).base + base := userType(reflect.TypeOf(value)).base // Check for incompatible duplicates. if t, ok := nameToConcreteType[name]; ok && t != base { panic("gob: registering duplicate types for " + name) @@ -732,7 +716,7 @@ func RegisterName(name string, value interface{}) { panic("gob: registering duplicate names for " + base.String()) } // Store the name and type provided by the user.... - nameToConcreteType[name] = reflect.Typeof(value) + nameToConcreteType[name] = reflect.TypeOf(value) // but the flattened type in the type table, since that's what decode needs. concreteTypeToName[base] = name } @@ -745,7 +729,7 @@ func RegisterName(name string, value interface{}) { // between types and names is not a bijection. func Register(value interface{}) { // Default to printed representation for unnamed types - rt := reflect.Typeof(value) + rt := reflect.TypeOf(value) name := rt.String() // But for named types (or pointers to them), qualify with import path. diff --git a/src/pkg/gob/type_test.go b/src/pkg/gob/type_test.go index ffd1345e5..411ffb797 100644 --- a/src/pkg/gob/type_test.go +++ b/src/pkg/gob/type_test.go @@ -47,15 +47,15 @@ func TestBasic(t *testing.T) { // Reregister some basic types to check registration is idempotent. func TestReregistration(t *testing.T) { - newtyp := getTypeUnlocked("int", reflect.Typeof(int(0))) + newtyp := getTypeUnlocked("int", reflect.TypeOf(int(0))) if newtyp != tInt.gobType() { t.Errorf("reregistration of %s got new type", newtyp.string()) } - newtyp = getTypeUnlocked("uint", reflect.Typeof(uint(0))) + newtyp = getTypeUnlocked("uint", reflect.TypeOf(uint(0))) if newtyp != tUint.gobType() { t.Errorf("reregistration of %s got new type", newtyp.string()) } - newtyp = getTypeUnlocked("string", reflect.Typeof("hello")) + newtyp = getTypeUnlocked("string", reflect.TypeOf("hello")) if newtyp != tString.gobType() { t.Errorf("reregistration of %s got new type", newtyp.string()) } @@ -63,18 +63,18 @@ func TestReregistration(t *testing.T) { func TestArrayType(t *testing.T) { var a3 [3]int - a3int := getTypeUnlocked("foo", reflect.Typeof(a3)) - newa3int := getTypeUnlocked("bar", reflect.Typeof(a3)) + a3int := getTypeUnlocked("foo", reflect.TypeOf(a3)) + newa3int := getTypeUnlocked("bar", reflect.TypeOf(a3)) if a3int != newa3int { t.Errorf("second registration of [3]int creates new type") } var a4 [4]int - a4int := getTypeUnlocked("goo", reflect.Typeof(a4)) + a4int := getTypeUnlocked("goo", reflect.TypeOf(a4)) if a3int == a4int { t.Errorf("registration of [3]int creates same type as [4]int") } var b3 [3]bool - a3bool := getTypeUnlocked("", reflect.Typeof(b3)) + a3bool := getTypeUnlocked("", reflect.TypeOf(b3)) if a3int == a3bool { t.Errorf("registration of [3]bool creates same type as [3]int") } @@ -87,14 +87,14 @@ func TestArrayType(t *testing.T) { func TestSliceType(t *testing.T) { var s []int - sint := getTypeUnlocked("slice", reflect.Typeof(s)) + sint := getTypeUnlocked("slice", reflect.TypeOf(s)) var news []int - newsint := getTypeUnlocked("slice1", reflect.Typeof(news)) + newsint := getTypeUnlocked("slice1", reflect.TypeOf(news)) if sint != newsint { t.Errorf("second registration of []int creates new type") } var b []bool - sbool := getTypeUnlocked("", reflect.Typeof(b)) + sbool := getTypeUnlocked("", reflect.TypeOf(b)) if sbool == sint { t.Errorf("registration of []bool creates same type as []int") } @@ -107,14 +107,14 @@ func TestSliceType(t *testing.T) { func TestMapType(t *testing.T) { var m map[string]int - mapStringInt := getTypeUnlocked("map", reflect.Typeof(m)) + mapStringInt := getTypeUnlocked("map", reflect.TypeOf(m)) var newm map[string]int - newMapStringInt := getTypeUnlocked("map1", reflect.Typeof(newm)) + newMapStringInt := getTypeUnlocked("map1", reflect.TypeOf(newm)) if mapStringInt != newMapStringInt { t.Errorf("second registration of map[string]int creates new type") } var b map[string]bool - mapStringBool := getTypeUnlocked("", reflect.Typeof(b)) + mapStringBool := getTypeUnlocked("", reflect.TypeOf(b)) if mapStringBool == mapStringInt { t.Errorf("registration of map[string]bool creates same type as map[string]int") } @@ -143,7 +143,7 @@ type Foo struct { } func TestStructType(t *testing.T) { - sstruct := getTypeUnlocked("Foo", reflect.Typeof(Foo{})) + sstruct := getTypeUnlocked("Foo", reflect.TypeOf(Foo{})) str := sstruct.string() // If we can print it correctly, we built it correctly. expected := "Foo = struct { A int; B int; C string; D bytes; E float; F float; G Bar = struct { X string; }; H Bar; I Foo; }" diff --git a/src/pkg/hash/adler32/adler32.go b/src/pkg/hash/adler32/adler32.go index cd0c2599a..84943d9ae 100644 --- a/src/pkg/hash/adler32/adler32.go +++ b/src/pkg/hash/adler32/adler32.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements the Adler-32 checksum. +// Package adler32 implements the Adler-32 checksum. // Defined in RFC 1950: // Adler-32 is composed of two sums accumulated per byte: s1 is // the sum of all bytes, s2 is the sum of all s1 values. Both sums @@ -43,8 +43,8 @@ func (d *digest) Size() int { return Size } // Add p to the running checksum a, b. func update(a, b uint32, p []byte) (aa, bb uint32) { - for i := 0; i < len(p); i++ { - a += uint32(p[i]) + for _, pi := range p { + a += uint32(pi) b += a // invariant: a <= b if b > (0xffffffff-255)/2 { diff --git a/src/pkg/hash/adler32/adler32_test.go b/src/pkg/hash/adler32/adler32_test.go index ffa5569bc..01f931c68 100644 --- a/src/pkg/hash/adler32/adler32_test.go +++ b/src/pkg/hash/adler32/adler32_test.go @@ -5,6 +5,7 @@ package adler32 import ( + "bytes" "io" "testing" ) @@ -61,3 +62,16 @@ func TestGolden(t *testing.T) { } } } + +func BenchmarkGolden(b *testing.B) { + b.StopTimer() + c := New() + var buf bytes.Buffer + for _, g := range golden { + buf.Write([]byte(g.in)) + } + b.StartTimer() + for i := 0; i < b.N; i++ { + c.Write(buf.Bytes()) + } +} diff --git a/src/pkg/hash/crc32/crc32.go b/src/pkg/hash/crc32/crc32.go index 2ab0c5491..88a449971 100644 --- a/src/pkg/hash/crc32/crc32.go +++ b/src/pkg/hash/crc32/crc32.go @@ -2,8 +2,9 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements the 32-bit cyclic redundancy check, or CRC-32, checksum. -// See http://en.wikipedia.org/wiki/Cyclic_redundancy_check for information. +// Package crc32 implements the 32-bit cyclic redundancy check, or CRC-32, +// checksum. See http://en.wikipedia.org/wiki/Cyclic_redundancy_check for +// information. package crc32 import ( diff --git a/src/pkg/hash/crc64/crc64.go b/src/pkg/hash/crc64/crc64.go index 844386564..ae37e781c 100644 --- a/src/pkg/hash/crc64/crc64.go +++ b/src/pkg/hash/crc64/crc64.go @@ -2,8 +2,9 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements the 64-bit cyclic redundancy check, or CRC-64, checksum. -// See http://en.wikipedia.org/wiki/Cyclic_redundancy_check for information. +// Package crc64 implements the 64-bit cyclic redundancy check, or CRC-64, +// checksum. See http://en.wikipedia.org/wiki/Cyclic_redundancy_check for +// information. package crc64 import ( diff --git a/src/pkg/hash/fnv/fnv.go b/src/pkg/hash/fnv/fnv.go index 66ab5a635..9a1c6a0f2 100644 --- a/src/pkg/hash/fnv/fnv.go +++ b/src/pkg/hash/fnv/fnv.go @@ -2,9 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The fnv package implements FNV-1 and FNV-1a, -// non-cryptographic hash functions created by -// Glenn Fowler, Landon Curt Noll, and Phong Vo. +// Package fnv implements FNV-1 and FNV-1a, non-cryptographic hash functions +// created by Glenn Fowler, Landon Curt Noll, and Phong Vo. // See http://isthe.com/chongo/tech/comp/fnv/. package fnv diff --git a/src/pkg/hash/hash.go b/src/pkg/hash/hash.go index 56ac259db..3536c0b6a 100644 --- a/src/pkg/hash/hash.go +++ b/src/pkg/hash/hash.go @@ -2,6 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +// Package hash provides interfaces for hash functions. package hash import "io" diff --git a/src/pkg/html/doc.go b/src/pkg/html/doc.go index 4f5dee72d..55135c3d0 100644 --- a/src/pkg/html/doc.go +++ b/src/pkg/html/doc.go @@ -3,7 +3,7 @@ // license that can be found in the LICENSE file. /* -The html package implements an HTML5-compliant tokenizer and parser. +Package html implements an HTML5-compliant tokenizer and parser. Tokenization is done by creating a Tokenizer for an io.Reader r. It is the caller's responsibility to ensure that r provides UTF-8 encoded HTML. diff --git a/src/pkg/html/parse_test.go b/src/pkg/html/parse_test.go index fe955436c..3fa35d5db 100644 --- a/src/pkg/html/parse_test.go +++ b/src/pkg/html/parse_test.go @@ -15,12 +15,6 @@ import ( "testing" ) -type devNull struct{} - -func (devNull) Write(p []byte) (int, os.Error) { - return len(p), nil -} - func pipeErr(err os.Error) io.Reader { pr, pw := io.Pipe() pw.CloseWithError(err) @@ -141,7 +135,7 @@ func TestParser(t *testing.T) { t.Fatal(err) } // Skip the #error section. - if _, err := io.Copy(devNull{}, <-rc); err != nil { + if _, err := io.Copy(ioutil.Discard, <-rc); err != nil { t.Fatal(err) } // Compare the parsed tree to the #document section. diff --git a/src/pkg/http/Makefile b/src/pkg/http/Makefile index 389b04222..2a2a2a3be 100644 --- a/src/pkg/http/Makefile +++ b/src/pkg/http/Makefile @@ -16,6 +16,7 @@ GOFILES=\ persist.go\ request.go\ response.go\ + reverseproxy.go\ server.go\ status.go\ transfer.go\ diff --git a/src/pkg/http/cgi/host.go b/src/pkg/http/cgi/host.go index a713d7c3c..136d4e4ee 100644 --- a/src/pkg/http/cgi/host.go +++ b/src/pkg/http/cgi/host.go @@ -25,20 +25,40 @@ import ( "os" "path/filepath" "regexp" + "runtime" "strconv" "strings" ) var trailingPort = regexp.MustCompile(`:([0-9]+)$`) +var osDefaultInheritEnv = map[string][]string{ + "darwin": []string{"DYLD_LIBRARY_PATH"}, + "freebsd": []string{"LD_LIBRARY_PATH"}, + "hpux": []string{"LD_LIBRARY_PATH", "SHLIB_PATH"}, + "linux": []string{"LD_LIBRARY_PATH"}, + "windows": []string{"SystemRoot", "COMSPEC", "PATHEXT", "WINDIR"}, +} + // Handler runs an executable in a subprocess with a CGI environment. type Handler struct { Path string // path to the CGI executable Root string // root URI prefix of handler or empty for "/" - Env []string // extra environment variables to set, if any - Logger *log.Logger // optional log for errors or nil to use log.Print - Args []string // optional arguments to pass to child process + Env []string // extra environment variables to set, if any, as "key=value" + InheritEnv []string // environment variables to inherit from host, as "key" + Logger *log.Logger // optional log for errors or nil to use log.Print + Args []string // optional arguments to pass to child process + + // PathLocationHandler specifies the root http Handler that + // should handle internal redirects when the CGI process + // returns a Location header value starting with a "/", as + // specified in RFC 3875 § 6.3.2. This will likely be + // http.DefaultServeMux. + // + // If nil, a CGI response with a local URI path is instead sent + // back to the client and not redirected internally. + PathLocationHandler http.Handler } func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { @@ -110,6 +130,24 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { env = append(env, h.Env...) } + path := os.Getenv("PATH") + if path == "" { + path = "/bin:/usr/bin:/usr/ucb:/usr/bsd:/usr/local/bin" + } + env = append(env, "PATH="+path) + + for _, e := range h.InheritEnv { + if v := os.Getenv(e); v != "" { + env = append(env, e+"="+v) + } + } + + for _, e := range osDefaultInheritEnv[runtime.GOOS] { + if v := os.Getenv(e); v != "" { + env = append(env, e+"="+v) + } + } + cwd, pathBase := filepath.Split(h.Path) if cwd == "" { cwd = "." @@ -143,13 +181,13 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { } linebody, _ := bufio.NewReaderSize(cmd.Stdout, 1024) - headers := rw.Header() - statusCode := http.StatusOK + headers := make(http.Header) + statusCode := 0 for { line, isPrefix, err := linebody.ReadLine() if isPrefix { rw.WriteHeader(http.StatusInternalServerError) - h.printf("CGI: long header line from subprocess.") + h.printf("cgi: long header line from subprocess.") return } if err == os.EOF { @@ -157,7 +195,7 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { } if err != nil { rw.WriteHeader(http.StatusInternalServerError) - h.printf("CGI: error reading headers: %v", err) + h.printf("cgi: error reading headers: %v", err) return } if len(line) == 0 { @@ -165,7 +203,7 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { } parts := strings.Split(string(line), ":", 2) if len(parts) < 2 { - h.printf("CGI: bogus header line: %s", string(line)) + h.printf("cgi: bogus header line: %s", string(line)) continue } header, val := parts[0], parts[1] @@ -174,13 +212,13 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { switch { case header == "Status": if len(val) < 3 { - h.printf("CGI: bogus status (short): %q", val) + h.printf("cgi: bogus status (short): %q", val) return } code, err := strconv.Atoi(val[0:3]) if err != nil { - h.printf("CGI: bogus status: %q", val) - h.printf("CGI: line was %q", line) + h.printf("cgi: bogus status: %q", val) + h.printf("cgi: line was %q", line) return } statusCode = code @@ -188,11 +226,35 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { headers.Add(header, val) } } + + if loc := headers.Get("Location"); loc != "" { + if strings.HasPrefix(loc, "/") && h.PathLocationHandler != nil { + h.handleInternalRedirect(rw, req, loc) + return + } + if statusCode == 0 { + statusCode = http.StatusFound + } + } + + if statusCode == 0 { + statusCode = http.StatusOK + } + + // Copy headers to rw's headers, after we've decided not to + // go into handleInternalRedirect, which won't want its rw + // headers to have been touched. + for k, vv := range headers { + for _, v := range vv { + rw.Header().Add(k, v) + } + } + rw.WriteHeader(statusCode) _, err = io.Copy(rw, linebody) if err != nil { - h.printf("CGI: copy error: %v", err) + h.printf("cgi: copy error: %v", err) } } @@ -204,6 +266,37 @@ func (h *Handler) printf(format string, v ...interface{}) { } } +func (h *Handler) handleInternalRedirect(rw http.ResponseWriter, req *http.Request, path string) { + url, err := req.URL.ParseURL(path) + if err != nil { + rw.WriteHeader(http.StatusInternalServerError) + h.printf("cgi: error resolving local URI path %q: %v", path, err) + return + } + // TODO: RFC 3875 isn't clear if only GET is supported, but it + // suggests so: "Note that any message-body attached to the + // request (such as for a POST request) may not be available + // to the resource that is the target of the redirect." We + // should do some tests against Apache to see how it handles + // POST, HEAD, etc. Does the internal redirect get the same + // method or just GET? What about incoming headers? + // (e.g. Cookies) Which headers, if any, are copied into the + // second request? + newReq := &http.Request{ + Method: "GET", + URL: url, + RawURL: path, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: make(http.Header), + Host: url.Host, + RemoteAddr: req.RemoteAddr, + TLS: req.TLS, + } + h.PathLocationHandler.ServeHTTP(rw, newReq) +} + func upperCaseAndUnderscore(rune int) int { switch { case rune >= 'a' && rune <= 'z': diff --git a/src/pkg/http/cgi/host_test.go b/src/pkg/http/cgi/host_test.go index e8084b113..9ac085f2f 100644 --- a/src/pkg/http/cgi/host_test.go +++ b/src/pkg/http/cgi/host_test.go @@ -271,3 +271,40 @@ Transfer-Encoding: chunked expected, got) } } + +func TestRedirect(t *testing.T) { + if skipTest(t) { + return + } + h := &Handler{ + Path: "testdata/test.cgi", + Root: "/test.cgi", + } + rec := runCgiTest(t, h, "GET /test.cgi?loc=http://foo.com/ HTTP/1.0\nHost: example.com\n\n", nil) + if e, g := 302, rec.Code; e != g { + t.Errorf("expected status code %d; got %d", e, g) + } + if e, g := "http://foo.com/", rec.Header().Get("Location"); e != g { + t.Errorf("expected Location header of %q; got %q", e, g) + } +} + +func TestInternalRedirect(t *testing.T) { + if skipTest(t) { + return + } + baseHandler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + fmt.Fprintf(rw, "basepath=%s\n", req.URL.Path) + fmt.Fprintf(rw, "remoteaddr=%s\n", req.RemoteAddr) + }) + h := &Handler{ + Path: "testdata/test.cgi", + Root: "/test.cgi", + PathLocationHandler: baseHandler, + } + expectedMap := map[string]string{ + "basepath": "/foo", + "remoteaddr": "1.2.3.4", + } + runCgiTest(t, h, "GET /test.cgi?loc=/foo HTTP/1.0\nHost: example.com\n\n", expectedMap) +} diff --git a/src/pkg/http/cgi/testdata/test.cgi b/src/pkg/http/cgi/testdata/test.cgi index 253589eed..a1b2ff893 100755 --- a/src/pkg/http/cgi/testdata/test.cgi +++ b/src/pkg/http/cgi/testdata/test.cgi @@ -11,6 +11,11 @@ use CGI; my $q = CGI->new; my $params = $q->Vars; +if ($params->{"loc"}) { + print "Location: $params->{loc}\r\n\r\n"; + exit(0); +} + my $NL = "\r\n"; $NL = "\n" if $params->{mode} eq "NL"; diff --git a/src/pkg/http/client.go b/src/pkg/http/client.go index daba3a89b..d73cbc855 100644 --- a/src/pkg/http/client.go +++ b/src/pkg/http/client.go @@ -22,6 +22,16 @@ import ( // Client is not yet very configurable. type Client struct { Transport RoundTripper // if nil, DefaultTransport is used + + // If CheckRedirect is not nil, the client calls it before + // following an HTTP redirect. The arguments req and via + // are the upcoming request and the requests made already, + // oldest first. If CheckRedirect returns an error, the client + // returns that error instead of issue the Request req. + // + // If CheckRedirect is nil, the Client uses its default policy, + // which is to stop after 10 consecutive requests. + CheckRedirect func(req *Request, via []*Request) os.Error } // DefaultClient is the default Client and is used by Get, Head, and Post. @@ -109,7 +119,7 @@ func shouldRedirect(statusCode int) bool { } // Get issues a GET to the specified URL. If the response is one of the following -// redirect codes, it follows the redirect, up to a maximum of 10 redirects: +// redirect codes, Get follows the redirect, up to a maximum of 10 redirects: // // 301 (Moved Permanently) // 302 (Found) @@ -126,35 +136,33 @@ func Get(url string) (r *Response, finalURL string, err os.Error) { return DefaultClient.Get(url) } -// Get issues a GET to the specified URL. If the response is one of the following -// redirect codes, it follows the redirect, up to a maximum of 10 redirects: +// Get issues a GET to the specified URL. If the response is one of the +// following redirect codes, Get follows the redirect after calling the +// Client's CheckRedirect function. // // 301 (Moved Permanently) // 302 (Found) // 303 (See Other) // 307 (Temporary Redirect) // -// finalURL is the URL from which the response was fetched -- identical to the -// input URL unless redirects were followed. +// finalURL is the URL from which the response was fetched -- identical +// to the input URL unless redirects were followed. // // Caller should close r.Body when done reading from it. func (c *Client) Get(url string) (r *Response, finalURL string, err os.Error) { // TODO: if/when we add cookie support, the redirected request shouldn't // necessarily supply the same cookies as the original. - // TODO: set referrer header on redirects. var base *URL - // TODO: remove this hard-coded 10 and use the Client's policy - // (ClientConfig) instead. - for redirect := 0; ; redirect++ { - if redirect >= 10 { - err = os.ErrorString("stopped after 10 redirects") - break - } + redirectChecker := c.CheckRedirect + if redirectChecker == nil { + redirectChecker = defaultCheckRedirect + } + var via []*Request + for redirect := 0; ; redirect++ { var req Request req.Method = "GET" - req.ProtoMajor = 1 - req.ProtoMinor = 1 + req.Header = make(Header) if base == nil { req.URL, err = ParseURL(url) } else { @@ -163,6 +171,19 @@ func (c *Client) Get(url string) (r *Response, finalURL string, err os.Error) { if err != nil { break } + if len(via) > 0 { + // Add the Referer header. + lastReq := via[len(via)-1] + if lastReq.URL.Scheme != "https" { + req.Referer = lastReq.URL.String() + } + + err = redirectChecker(&req, via) + if err != nil { + break + } + } + url = req.URL.String() if r, err = send(&req, c.Transport); err != nil { break @@ -174,6 +195,7 @@ func (c *Client) Get(url string) (r *Response, finalURL string, err os.Error) { break } base = req.URL + via = append(via, &req) continue } finalURL = url @@ -184,6 +206,13 @@ func (c *Client) Get(url string) (r *Response, finalURL string, err os.Error) { return } +func defaultCheckRedirect(req *Request, via []*Request) os.Error { + if len(via) >= 10 { + return os.ErrorString("stopped after 10 redirects") + } + return nil +} + // Post issues a POST to the specified URL. // // Caller should close r.Body when done reading from it. diff --git a/src/pkg/http/client_test.go b/src/pkg/http/client_test.go index 3a6f83425..59d62c1c9 100644 --- a/src/pkg/http/client_test.go +++ b/src/pkg/http/client_test.go @@ -12,6 +12,7 @@ import ( "http/httptest" "io/ioutil" "os" + "strconv" "strings" "testing" ) @@ -75,3 +76,51 @@ func TestGetRequestFormat(t *testing.T) { t.Errorf("expected non-nil request Header") } } + +func TestRedirects(t *testing.T) { + var ts *httptest.Server + ts = httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + n, _ := strconv.Atoi(r.FormValue("n")) + // Test Referer header. (7 is arbitrary position to test at) + if n == 7 { + if g, e := r.Referer, ts.URL+"/?n=6"; e != g { + t.Errorf("on request ?n=7, expected referer of %q; got %q", e, g) + } + } + if n < 15 { + Redirect(w, r, fmt.Sprintf("/?n=%d", n+1), StatusFound) + return + } + fmt.Fprintf(w, "n=%d", n) + })) + defer ts.Close() + + c := &Client{} + _, _, err := c.Get(ts.URL) + if e, g := "Get /?n=10: stopped after 10 redirects", fmt.Sprintf("%v", err); e != g { + t.Errorf("with default client, expected error %q, got %q", e, g) + } + + var checkErr os.Error + var lastVia []*Request + c = &Client{CheckRedirect: func(_ *Request, via []*Request) os.Error { + lastVia = via + return checkErr + }} + _, finalUrl, err := c.Get(ts.URL) + if e, g := "<nil>", fmt.Sprintf("%v", err); e != g { + t.Errorf("with custom client, expected error %q, got %q", e, g) + } + if !strings.HasSuffix(finalUrl, "/?n=15") { + t.Errorf("expected final url to end in /?n=15; got url %q", finalUrl) + } + if e, g := 15, len(lastVia); e != g { + t.Errorf("expected lastVia to have contained %d elements; got %d", e, g) + } + + checkErr = os.NewError("no redirects allowed") + _, finalUrl, err = c.Get(ts.URL) + if e, g := "Get /?n=1: no redirects allowed", fmt.Sprintf("%v", err); e != g { + t.Errorf("with redirects forbidden, expected error %q, got %q", e, g) + } +} diff --git a/src/pkg/http/cookie.go b/src/pkg/http/cookie.go index 2bb66e58e..2c01826a1 100644 --- a/src/pkg/http/cookie.go +++ b/src/pkg/http/cookie.go @@ -142,12 +142,12 @@ func writeSetCookies(w io.Writer, kk []*Cookie) os.Error { var b bytes.Buffer for _, c := range kk { b.Reset() - fmt.Fprintf(&b, "%s=%s", c.Name, c.Value) + fmt.Fprintf(&b, "%s=%s", sanitizeName(c.Name), sanitizeValue(c.Value)) if len(c.Path) > 0 { - fmt.Fprintf(&b, "; Path=%s", URLEscape(c.Path)) + fmt.Fprintf(&b, "; Path=%s", sanitizeValue(c.Path)) } if len(c.Domain) > 0 { - fmt.Fprintf(&b, "; Domain=%s", URLEscape(c.Domain)) + fmt.Fprintf(&b, "; Domain=%s", sanitizeValue(c.Domain)) } if len(c.Expires.Zone) > 0 { fmt.Fprintf(&b, "; Expires=%s", c.Expires.Format(time.RFC1123)) @@ -225,7 +225,7 @@ func readCookies(h Header) []*Cookie { func writeCookies(w io.Writer, kk []*Cookie) os.Error { lines := make([]string, 0, len(kk)) for _, c := range kk { - lines = append(lines, fmt.Sprintf("Cookie: %s=%s\r\n", c.Name, c.Value)) + lines = append(lines, fmt.Sprintf("Cookie: %s=%s\r\n", sanitizeName(c.Name), sanitizeValue(c.Value))) } sort.SortStrings(lines) for _, l := range lines { @@ -236,6 +236,19 @@ func writeCookies(w io.Writer, kk []*Cookie) os.Error { return nil } +func sanitizeName(n string) string { + n = strings.Replace(n, "\n", "-", -1) + n = strings.Replace(n, "\r", "-", -1) + return n +} + +func sanitizeValue(v string) string { + v = strings.Replace(v, "\n", " ", -1) + v = strings.Replace(v, "\r", " ", -1) + v = strings.Replace(v, ";", " ", -1) + return v +} + func unquoteCookieValue(v string) string { if len(v) > 1 && v[0] == '"' && v[len(v)-1] == '"' { return v[1 : len(v)-1] diff --git a/src/pkg/http/cookie_test.go b/src/pkg/http/cookie_test.go index db0997040..a3ae85cd6 100644 --- a/src/pkg/http/cookie_test.go +++ b/src/pkg/http/cookie_test.go @@ -21,9 +21,13 @@ var writeSetCookiesTests = []struct { []*Cookie{ &Cookie{Name: "cookie-1", Value: "v$1"}, &Cookie{Name: "cookie-2", Value: "two", MaxAge: 3600}, + &Cookie{Name: "cookie-3", Value: "three", Domain: ".example.com"}, + &Cookie{Name: "cookie-4", Value: "four", Path: "/restricted/"}, }, "Set-Cookie: cookie-1=v$1\r\n" + - "Set-Cookie: cookie-2=two; Max-Age=3600\r\n", + "Set-Cookie: cookie-2=two; Max-Age=3600\r\n" + + "Set-Cookie: cookie-3=three; Domain=.example.com\r\n" + + "Set-Cookie: cookie-4=four; Path=/restricted/\r\n", }, } diff --git a/src/pkg/http/dump.go b/src/pkg/http/dump.go index 306c45bc2..358980f7c 100644 --- a/src/pkg/http/dump.go +++ b/src/pkg/http/dump.go @@ -31,6 +31,8 @@ func drainBody(b io.ReadCloser) (r1, r2 io.ReadCloser, err os.Error) { // DumpRequest is semantically a no-op, but in order to // dump the body, it reads the body data into memory and // changes req.Body to refer to the in-memory copy. +// The documentation for Request.Write details which fields +// of req are used. func DumpRequest(req *Request, body bool) (dump []byte, err os.Error) { var b bytes.Buffer save := req.Body diff --git a/src/pkg/http/export_test.go b/src/pkg/http/export_test.go index 47c687760..3fe658641 100644 --- a/src/pkg/http/export_test.go +++ b/src/pkg/http/export_test.go @@ -32,3 +32,10 @@ func (t *Transport) IdleConnCountForTesting(cacheKey string) int { } return len(conns) } + +func NewTestTimeoutHandler(handler Handler, ch <-chan int64) Handler { + f := func() <-chan int64 { + return ch + } + return &timeoutHandler{handler, f, ""} +} diff --git a/src/pkg/http/fcgi/Makefile b/src/pkg/http/fcgi/Makefile new file mode 100644 index 000000000..bc01cdea9 --- /dev/null +++ b/src/pkg/http/fcgi/Makefile @@ -0,0 +1,12 @@ +# Copyright 2011 The Go Authors. All rights reserved. +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +include ../../../Make.inc + +TARG=http/fcgi +GOFILES=\ + child.go\ + fcgi.go\ + +include ../../../Make.pkg diff --git a/src/pkg/http/fcgi/child.go b/src/pkg/http/fcgi/child.go new file mode 100644 index 000000000..114052bee --- /dev/null +++ b/src/pkg/http/fcgi/child.go @@ -0,0 +1,328 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package fcgi + +// This file implements FastCGI from the perspective of a child process. + +import ( + "fmt" + "http" + "io" + "net" + "os" + "strconv" + "strings" + "time" +) + +// request holds the state for an in-progress request. As soon as it's complete, +// it's converted to an http.Request. +type request struct { + pw *io.PipeWriter + reqId uint16 + params map[string]string + buf [1024]byte + rawParams []byte + keepConn bool +} + +func newRequest(reqId uint16, flags uint8) *request { + r := &request{ + reqId: reqId, + params: map[string]string{}, + keepConn: flags&flagKeepConn != 0, + } + r.rawParams = r.buf[:0] + return r +} + +// TODO(eds): copied from http/cgi +var skipHeader = map[string]bool{ + "HTTP_HOST": true, + "HTTP_REFERER": true, + "HTTP_USER_AGENT": true, +} + +// httpRequest converts r to an http.Request. +// TODO(eds): this is very similar to http/cgi's requestFromEnvironment +func (r *request) httpRequest(body io.ReadCloser) (*http.Request, os.Error) { + req := &http.Request{ + Method: r.params["REQUEST_METHOD"], + RawURL: r.params["REQUEST_URI"], + Body: body, + Header: http.Header{}, + Trailer: http.Header{}, + Proto: r.params["SERVER_PROTOCOL"], + } + + var ok bool + req.ProtoMajor, req.ProtoMinor, ok = http.ParseHTTPVersion(req.Proto) + if !ok { + return nil, os.NewError("fcgi: invalid HTTP version") + } + + req.Host = r.params["HTTP_HOST"] + req.Referer = r.params["HTTP_REFERER"] + req.UserAgent = r.params["HTTP_USER_AGENT"] + + if lenstr := r.params["CONTENT_LENGTH"]; lenstr != "" { + clen, err := strconv.Atoi64(r.params["CONTENT_LENGTH"]) + if err != nil { + return nil, os.NewError("fcgi: bad CONTENT_LENGTH parameter: " + lenstr) + } + req.ContentLength = clen + } + + if req.Host != "" { + req.RawURL = "http://" + req.Host + r.params["REQUEST_URI"] + url, err := http.ParseURL(req.RawURL) + if err != nil { + return nil, os.NewError("fcgi: failed to parse host and REQUEST_URI into a URL: " + req.RawURL) + } + req.URL = url + } + if req.URL == nil { + req.RawURL = r.params["REQUEST_URI"] + url, err := http.ParseURL(req.RawURL) + if err != nil { + return nil, os.NewError("fcgi: failed to parse REQUEST_URI into a URL: " + req.RawURL) + } + req.URL = url + } + + for key, val := range r.params { + if strings.HasPrefix(key, "HTTP_") && !skipHeader[key] { + req.Header.Add(strings.Replace(key[5:], "_", "-", -1), val) + } + } + return req, nil +} + +// parseParams reads an encoded []byte into Params. +func (r *request) parseParams() { + text := r.rawParams + r.rawParams = nil + for len(text) > 0 { + keyLen, n := readSize(text) + if n == 0 { + return + } + text = text[n:] + valLen, n := readSize(text) + if n == 0 { + return + } + text = text[n:] + key := readString(text, keyLen) + text = text[keyLen:] + val := readString(text, valLen) + text = text[valLen:] + r.params[key] = val + } +} + +// response implements http.ResponseWriter. +type response struct { + req *request + header http.Header + w *bufWriter + wroteHeader bool +} + +func newResponse(c *child, req *request) *response { + return &response{ + req: req, + header: http.Header{}, + w: newWriter(c.conn, typeStdout, req.reqId), + } +} + +func (r *response) Header() http.Header { + return r.header +} + +func (r *response) Write(data []byte) (int, os.Error) { + if !r.wroteHeader { + r.WriteHeader(http.StatusOK) + } + return r.w.Write(data) +} + +func (r *response) WriteHeader(code int) { + if r.wroteHeader { + return + } + r.wroteHeader = true + if code == http.StatusNotModified { + // Must not have body. + r.header.Del("Content-Type") + r.header.Del("Content-Length") + r.header.Del("Transfer-Encoding") + } else if r.header.Get("Content-Type") == "" { + r.header.Set("Content-Type", "text/html; charset=utf-8") + } + + if r.header.Get("Date") == "" { + r.header.Set("Date", time.UTC().Format(http.TimeFormat)) + } + + fmt.Fprintf(r.w, "Status: %d %s\r\n", code, http.StatusText(code)) + // TODO(eds): this is duplicated in http and http/cgi + for k, vv := range r.header { + for _, v := range vv { + v = strings.Replace(v, "\n", "", -1) + v = strings.Replace(v, "\r", "", -1) + v = strings.TrimSpace(v) + fmt.Fprintf(r.w, "%s: %s\r\n", k, v) + } + } + r.w.WriteString("\r\n") +} + +func (r *response) Flush() { + if !r.wroteHeader { + r.WriteHeader(http.StatusOK) + } + r.w.Flush() +} + +func (r *response) Close() os.Error { + r.Flush() + return r.w.Close() +} + +type child struct { + conn *conn + handler http.Handler +} + +func newChild(rwc net.Conn, handler http.Handler) *child { + return &child{newConn(rwc), handler} +} + +func (c *child) serve() { + requests := map[uint16]*request{} + defer c.conn.Close() + var rec record + var br beginRequest + for { + if err := rec.read(c.conn.rwc); err != nil { + return + } + + req, ok := requests[rec.h.Id] + if !ok && rec.h.Type != typeBeginRequest && rec.h.Type != typeGetValues { + // The spec says to ignore unknown request IDs. + continue + } + if ok && rec.h.Type == typeBeginRequest { + // The server is trying to begin a request with the same ID + // as an in-progress request. This is an error. + return + } + + switch rec.h.Type { + case typeBeginRequest: + if err := br.read(rec.content()); err != nil { + return + } + if br.role != roleResponder { + c.conn.writeEndRequest(rec.h.Id, 0, statusUnknownRole) + break + } + requests[rec.h.Id] = newRequest(rec.h.Id, br.flags) + case typeParams: + // NOTE(eds): Technically a key-value pair can straddle the boundary + // between two packets. We buffer until we've received all parameters. + if len(rec.content()) > 0 { + req.rawParams = append(req.rawParams, rec.content()...) + break + } + req.parseParams() + case typeStdin: + content := rec.content() + if req.pw == nil { + var body io.ReadCloser + if len(content) > 0 { + // body could be an io.LimitReader, but it shouldn't matter + // as long as both sides are behaving. + body, req.pw = io.Pipe() + } + go c.serveRequest(req, body) + } + if len(content) > 0 { + // TODO(eds): This blocks until the handler reads from the pipe. + // If the handler takes a long time, it might be a problem. + req.pw.Write(content) + } else if req.pw != nil { + req.pw.Close() + } + case typeGetValues: + values := map[string]string{"FCGI_MPXS_CONNS": "1"} + c.conn.writePairs(0, typeGetValuesResult, values) + case typeData: + // If the filter role is implemented, read the data stream here. + case typeAbortRequest: + requests[rec.h.Id] = nil, false + c.conn.writeEndRequest(rec.h.Id, 0, statusRequestComplete) + if !req.keepConn { + // connection will close upon return + return + } + default: + b := make([]byte, 8) + b[0] = rec.h.Type + c.conn.writeRecord(typeUnknownType, 0, b) + } + } +} + +func (c *child) serveRequest(req *request, body io.ReadCloser) { + r := newResponse(c, req) + httpReq, err := req.httpRequest(body) + if err != nil { + // there was an error reading the request + r.WriteHeader(http.StatusInternalServerError) + c.conn.writeRecord(typeStderr, req.reqId, []byte(err.String())) + } else { + c.handler.ServeHTTP(r, httpReq) + } + if body != nil { + body.Close() + } + r.Close() + c.conn.writeEndRequest(req.reqId, 0, statusRequestComplete) + if !req.keepConn { + c.conn.Close() + } +} + +// Serve accepts incoming FastCGI connections on the listener l, creating a new +// service thread for each. The service threads read requests and then call handler +// to reply to them. +// If l is nil, Serve accepts connections on stdin. +// If handler is nil, http.DefaultServeMux is used. +func Serve(l net.Listener, handler http.Handler) os.Error { + if l == nil { + var err os.Error + l, err = net.FileListener(os.Stdin) + if err != nil { + return err + } + defer l.Close() + } + if handler == nil { + handler = http.DefaultServeMux + } + for { + rw, err := l.Accept() + if err != nil { + return err + } + c := newChild(rw, handler) + go c.serve() + } + panic("unreachable") +} diff --git a/src/pkg/http/fcgi/fcgi.go b/src/pkg/http/fcgi/fcgi.go new file mode 100644 index 000000000..8e2e1cd3c --- /dev/null +++ b/src/pkg/http/fcgi/fcgi.go @@ -0,0 +1,271 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package fcgi implements the FastCGI protocol. +// Currently only the responder role is supported. +// The protocol is defined at http://www.fastcgi.com/drupal/node/6?q=node/22 +package fcgi + +// This file defines the raw protocol and some utilities used by the child and +// the host. + +import ( + "bufio" + "bytes" + "encoding/binary" + "io" + "os" + "sync" +) + +const ( + // Packet Types + typeBeginRequest = iota + 1 + typeAbortRequest + typeEndRequest + typeParams + typeStdin + typeStdout + typeStderr + typeData + typeGetValues + typeGetValuesResult + typeUnknownType +) + +// keep the connection between web-server and responder open after request +const flagKeepConn = 1 + +const ( + maxWrite = 65535 // maximum record body + maxPad = 255 +) + +const ( + roleResponder = iota + 1 // only Responders are implemented. + roleAuthorizer + roleFilter +) + +const ( + statusRequestComplete = iota + statusCantMultiplex + statusOverloaded + statusUnknownRole +) + +const headerLen = 8 + +type header struct { + Version uint8 + Type uint8 + Id uint16 + ContentLength uint16 + PaddingLength uint8 + Reserved uint8 +} + +type beginRequest struct { + role uint16 + flags uint8 + reserved [5]uint8 +} + +func (br *beginRequest) read(content []byte) os.Error { + if len(content) != 8 { + return os.NewError("fcgi: invalid begin request record") + } + br.role = binary.BigEndian.Uint16(content) + br.flags = content[2] + return nil +} + +// for padding so we don't have to allocate all the time +// not synchronized because we don't care what the contents are +var pad [maxPad]byte + +func (h *header) init(recType uint8, reqId uint16, contentLength int) { + h.Version = 1 + h.Type = recType + h.Id = reqId + h.ContentLength = uint16(contentLength) + h.PaddingLength = uint8(-contentLength & 7) +} + +// conn sends records over rwc +type conn struct { + mutex sync.Mutex + rwc io.ReadWriteCloser + + // to avoid allocations + buf bytes.Buffer + h header +} + +func newConn(rwc io.ReadWriteCloser) *conn { + return &conn{rwc: rwc} +} + +func (c *conn) Close() os.Error { + c.mutex.Lock() + defer c.mutex.Unlock() + return c.rwc.Close() +} + +type record struct { + h header + buf [maxWrite + maxPad]byte +} + +func (rec *record) read(r io.Reader) (err os.Error) { + if err = binary.Read(r, binary.BigEndian, &rec.h); err != nil { + return err + } + if rec.h.Version != 1 { + return os.NewError("fcgi: invalid header version") + } + n := int(rec.h.ContentLength) + int(rec.h.PaddingLength) + if _, err = io.ReadFull(r, rec.buf[:n]); err != nil { + return err + } + return nil +} + +func (r *record) content() []byte { + return r.buf[:r.h.ContentLength] +} + +// writeRecord writes and sends a single record. +func (c *conn) writeRecord(recType uint8, reqId uint16, b []byte) os.Error { + c.mutex.Lock() + defer c.mutex.Unlock() + c.buf.Reset() + c.h.init(recType, reqId, len(b)) + if err := binary.Write(&c.buf, binary.BigEndian, c.h); err != nil { + return err + } + if _, err := c.buf.Write(b); err != nil { + return err + } + if _, err := c.buf.Write(pad[:c.h.PaddingLength]); err != nil { + return err + } + _, err := c.rwc.Write(c.buf.Bytes()) + return err +} + +func (c *conn) writeBeginRequest(reqId uint16, role uint16, flags uint8) os.Error { + b := [8]byte{byte(role >> 8), byte(role), flags} + return c.writeRecord(typeBeginRequest, reqId, b[:]) +} + +func (c *conn) writeEndRequest(reqId uint16, appStatus int, protocolStatus uint8) os.Error { + b := make([]byte, 8) + binary.BigEndian.PutUint32(b, uint32(appStatus)) + b[4] = protocolStatus + return c.writeRecord(typeEndRequest, reqId, b) +} + +func (c *conn) writePairs(recType uint8, reqId uint16, pairs map[string]string) os.Error { + w := newWriter(c, recType, reqId) + b := make([]byte, 8) + for k, v := range pairs { + n := encodeSize(b, uint32(len(k))) + n += encodeSize(b[n:], uint32(len(k))) + if _, err := w.Write(b[:n]); err != nil { + return err + } + if _, err := w.WriteString(k); err != nil { + return err + } + if _, err := w.WriteString(v); err != nil { + return err + } + } + w.Close() + return nil +} + +func readSize(s []byte) (uint32, int) { + if len(s) == 0 { + return 0, 0 + } + size, n := uint32(s[0]), 1 + if size&(1<<7) != 0 { + if len(s) < 4 { + return 0, 0 + } + n = 4 + size = binary.BigEndian.Uint32(s) + size &^= 1 << 31 + } + return size, n +} + +func readString(s []byte, size uint32) string { + if size > uint32(len(s)) { + return "" + } + return string(s[:size]) +} + +func encodeSize(b []byte, size uint32) int { + if size > 127 { + size |= 1 << 31 + binary.BigEndian.PutUint32(b, size) + return 4 + } + b[0] = byte(size) + return 1 +} + +// bufWriter encapsulates bufio.Writer but also closes the underlying stream when +// Closed. +type bufWriter struct { + closer io.Closer + *bufio.Writer +} + +func (w *bufWriter) Close() os.Error { + if err := w.Writer.Flush(); err != nil { + w.closer.Close() + return err + } + return w.closer.Close() +} + +func newWriter(c *conn, recType uint8, reqId uint16) *bufWriter { + s := &streamWriter{c: c, recType: recType, reqId: reqId} + w, _ := bufio.NewWriterSize(s, maxWrite) + return &bufWriter{s, w} +} + +// streamWriter abstracts out the separation of a stream into discrete records. +// It only writes maxWrite bytes at a time. +type streamWriter struct { + c *conn + recType uint8 + reqId uint16 +} + +func (w *streamWriter) Write(p []byte) (int, os.Error) { + nn := 0 + for len(p) > 0 { + n := len(p) + if n > maxWrite { + n = maxWrite + } + if err := w.c.writeRecord(w.recType, w.reqId, p[:n]); err != nil { + return nn, err + } + nn += n + p = p[n:] + } + return nn, nil +} + +func (w *streamWriter) Close() os.Error { + // send empty record to close the stream + return w.c.writeRecord(w.recType, w.reqId, nil) +} diff --git a/src/pkg/http/fcgi/fcgi_test.go b/src/pkg/http/fcgi/fcgi_test.go new file mode 100644 index 000000000..16a624329 --- /dev/null +++ b/src/pkg/http/fcgi/fcgi_test.go @@ -0,0 +1,114 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package fcgi + +import ( + "bytes" + "io" + "os" + "testing" +) + +var sizeTests = []struct { + size uint32 + bytes []byte +}{ + {0, []byte{0x00}}, + {127, []byte{0x7F}}, + {128, []byte{0x80, 0x00, 0x00, 0x80}}, + {1000, []byte{0x80, 0x00, 0x03, 0xE8}}, + {33554431, []byte{0x81, 0xFF, 0xFF, 0xFF}}, +} + +func TestSize(t *testing.T) { + b := make([]byte, 4) + for i, test := range sizeTests { + n := encodeSize(b, test.size) + if !bytes.Equal(b[:n], test.bytes) { + t.Errorf("%d expected %x, encoded %x", i, test.bytes, b) + } + size, n := readSize(test.bytes) + if size != test.size { + t.Errorf("%d expected %d, read %d", i, test.size, size) + } + if len(test.bytes) != n { + t.Errorf("%d did not consume all the bytes", i) + } + } +} + +var streamTests = []struct { + desc string + recType uint8 + reqId uint16 + content []byte + raw []byte +}{ + {"single record", typeStdout, 1, nil, + []byte{1, typeStdout, 0, 1, 0, 0, 0, 0}, + }, + // this data will have to be split into two records + {"two records", typeStdin, 300, make([]byte, 66000), + bytes.Join([][]byte{ + // header for the first record + []byte{1, typeStdin, 0x01, 0x2C, 0xFF, 0xFF, 1, 0}, + make([]byte, 65536), + // header for the second + []byte{1, typeStdin, 0x01, 0x2C, 0x01, 0xD1, 7, 0}, + make([]byte, 472), + // header for the empty record + []byte{1, typeStdin, 0x01, 0x2C, 0, 0, 0, 0}, + }, + nil), + }, +} + +type nilCloser struct { + io.ReadWriter +} + +func (c *nilCloser) Close() os.Error { return nil } + +func TestStreams(t *testing.T) { + var rec record +outer: + for _, test := range streamTests { + buf := bytes.NewBuffer(test.raw) + var content []byte + for buf.Len() > 0 { + if err := rec.read(buf); err != nil { + t.Errorf("%s: error reading record: %v", test.desc, err) + continue outer + } + content = append(content, rec.content()...) + } + if rec.h.Type != test.recType { + t.Errorf("%s: got type %d expected %d", test.desc, rec.h.Type, test.recType) + continue + } + if rec.h.Id != test.reqId { + t.Errorf("%s: got request ID %d expected %d", test.desc, rec.h.Id, test.reqId) + continue + } + if !bytes.Equal(content, test.content) { + t.Errorf("%s: read wrong content", test.desc) + continue + } + buf.Reset() + c := newConn(&nilCloser{buf}) + w := newWriter(c, test.recType, test.reqId) + if _, err := w.Write(test.content); err != nil { + t.Errorf("%s: error writing record: %v", test.desc, err) + continue + } + if err := w.Close(); err != nil { + t.Errorf("%s: error closing stream: %v", test.desc, err) + continue + } + if !bytes.Equal(buf.Bytes(), test.raw) { + t.Errorf("%s: wrote wrong content", test.desc) + } + } +} diff --git a/src/pkg/http/fs.go b/src/pkg/http/fs.go index c5efffca9..17d5297b8 100644 --- a/src/pkg/http/fs.go +++ b/src/pkg/http/fs.go @@ -143,7 +143,7 @@ func serveFile(w ResponseWriter, r *Request, name string, redirect bool) { n, _ := io.ReadFull(f, buf[:]) b := buf[:n] if isText(b) { - ctype = "text-plain; charset=utf-8" + ctype = "text/plain; charset=utf-8" } else { // generic binary ctype = "application/octet-stream" diff --git a/src/pkg/http/fs_test.go b/src/pkg/http/fs_test.go index 692b9863e..09d0981f2 100644 --- a/src/pkg/http/fs_test.go +++ b/src/pkg/http/fs_test.go @@ -104,7 +104,7 @@ func TestServeFileContentType(t *testing.T) { t.Errorf("Content-Type mismatch: got %q, want %q", h, want) } } - get("text-plain; charset=utf-8") + get("text/plain; charset=utf-8") override = true get(ctype) } diff --git a/src/pkg/http/httptest/recorder.go b/src/pkg/http/httptest/recorder.go index 0dd19a617..f2fedefcf 100644 --- a/src/pkg/http/httptest/recorder.go +++ b/src/pkg/http/httptest/recorder.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The httptest package provides utilities for HTTP testing. +// Package httptest provides utilities for HTTP testing. package httptest import ( diff --git a/src/pkg/http/persist.go b/src/pkg/http/persist.go index b93c5fe48..e4eea6815 100644 --- a/src/pkg/http/persist.go +++ b/src/pkg/http/persist.go @@ -20,8 +20,8 @@ var ( // A ServerConn reads requests and sends responses over an underlying // connection, until the HTTP keepalive logic commands an end. ServerConn -// does not close the underlying connection. Instead, the user calls Close -// and regains control over the connection. ServerConn supports pipe-lining, +// also allows hijacking the underlying connection by calling Hijack +// to regain control over the connection. ServerConn supports pipe-lining, // i.e. requests can be read out of sync (but in the same order) while the // respective responses are sent. type ServerConn struct { @@ -45,11 +45,11 @@ func NewServerConn(c net.Conn, r *bufio.Reader) *ServerConn { return &ServerConn{c: c, r: r, pipereq: make(map[*Request]uint)} } -// Close detaches the ServerConn and returns the underlying connection as well -// as the read-side bufio which may have some left over data. Close may be +// Hijack detaches the ServerConn and returns the underlying connection as well +// as the read-side bufio which may have some left over data. Hijack may be // called before Read has signaled the end of the keep-alive logic. The user -// should not call Close while Read or Write is in progress. -func (sc *ServerConn) Close() (c net.Conn, r *bufio.Reader) { +// should not call Hijack while Read or Write is in progress. +func (sc *ServerConn) Hijack() (c net.Conn, r *bufio.Reader) { sc.lk.Lock() defer sc.lk.Unlock() c = sc.c @@ -59,6 +59,15 @@ func (sc *ServerConn) Close() (c net.Conn, r *bufio.Reader) { return } +// Close calls Hijack and then also closes the underlying connection +func (sc *ServerConn) Close() os.Error { + c, _ := sc.Hijack() + if c != nil { + return c.Close() + } + return nil +} + // Read returns the next request on the wire. An ErrPersistEOF is returned if // it is gracefully determined that there are no more requests (e.g. after the // first request on an HTTP/1.0 connection, or after a Connection:close on a @@ -199,9 +208,9 @@ func (sc *ServerConn) Write(req *Request, resp *Response) os.Error { } // A ClientConn sends request and receives headers over an underlying -// connection, while respecting the HTTP keepalive logic. ClientConn is not -// responsible for closing the underlying connection. One must call Close to -// regain control of that connection and deal with it as desired. +// connection, while respecting the HTTP keepalive logic. ClientConn +// supports hijacking the connection calling Hijack to +// regain control of the underlying net.Conn and deal with it as desired. type ClientConn struct { lk sync.Mutex // read-write protects the following fields c net.Conn @@ -239,11 +248,11 @@ func NewProxyClientConn(c net.Conn, r *bufio.Reader) *ClientConn { return cc } -// Close detaches the ClientConn and returns the underlying connection as well -// as the read-side bufio which may have some left over data. Close may be +// Hijack detaches the ClientConn and returns the underlying connection as well +// as the read-side bufio which may have some left over data. Hijack may be // called before the user or Read have signaled the end of the keep-alive -// logic. The user should not call Close while Read or Write is in progress. -func (cc *ClientConn) Close() (c net.Conn, r *bufio.Reader) { +// logic. The user should not call Hijack while Read or Write is in progress. +func (cc *ClientConn) Hijack() (c net.Conn, r *bufio.Reader) { cc.lk.Lock() defer cc.lk.Unlock() c = cc.c @@ -253,6 +262,15 @@ func (cc *ClientConn) Close() (c net.Conn, r *bufio.Reader) { return } +// Close calls Hijack and then also closes the underlying connection +func (cc *ClientConn) Close() os.Error { + c, _ := cc.Hijack() + if c != nil { + return c.Close() + } + return nil +} + // Write writes a request. An ErrPersistEOF error is returned if the connection // has been closed in an HTTP keepalive sense. If req.Close equals true, the // keepalive connection is logically closed after this request and the opposing diff --git a/src/pkg/http/proxy_test.go b/src/pkg/http/proxy_test.go index 7050ef5ed..308bf44b4 100644 --- a/src/pkg/http/proxy_test.go +++ b/src/pkg/http/proxy_test.go @@ -16,9 +16,15 @@ var UseProxyTests = []struct { host string match bool }{ - {"localhost", false}, // match completely + // Never proxy localhost: + {"localhost:80", false}, + {"127.0.0.1", false}, + {"127.0.0.2", false}, + {"[::1]", false}, + {"[::2]", true}, // not a loopback address + {"barbaz.net", false}, // match as .barbaz.net - {"foobar.com:443", false}, // have a port but match + {"foobar.com", false}, // have a port but match {"foofoobar.com", true}, // not match as a part of foobar.com {"baz.com", true}, // not match as a part of barbaz.com {"localhost.net", true}, // not match as suffix of address @@ -29,19 +35,16 @@ var UseProxyTests = []struct { func TestUseProxy(t *testing.T) { oldenv := os.Getenv("NO_PROXY") - no_proxy := "foobar.com, .barbaz.net , localhost" - os.Setenv("NO_PROXY", no_proxy) defer os.Setenv("NO_PROXY", oldenv) + no_proxy := "foobar.com, .barbaz.net" + os.Setenv("NO_PROXY", no_proxy) + tr := &Transport{} for _, test := range UseProxyTests { - if tr.useProxy(test.host) != test.match { - if test.match { - t.Errorf("useProxy(%v) = %v, want %v", test.host, !test.match, test.match) - } else { - t.Errorf("not expected: '%s' shouldn't match as '%s'", test.host, no_proxy) - } + if tr.useProxy(test.host+":80") != test.match { + t.Errorf("useProxy(%v) = %v, want %v", test.host, !test.match, test.match) } } } diff --git a/src/pkg/http/request.go b/src/pkg/http/request.go index d82894fab..b8e9a2142 100644 --- a/src/pkg/http/request.go +++ b/src/pkg/http/request.go @@ -4,9 +4,8 @@ // HTTP Request reading and parsing. -// The http package implements parsing of HTTP requests, replies, -// and URLs and provides an extensible HTTP server and a basic -// HTTP client. +// Package http implements parsing of HTTP requests, replies, and URLs and +// provides an extensible HTTP server and a basic HTTP client. package http import ( @@ -25,12 +24,17 @@ import ( ) const ( - maxLineLength = 4096 // assumed <= bufio.defaultBufSize - maxValueLength = 4096 - maxHeaderLines = 1024 - chunkSize = 4 << 10 // 4 KB chunks + maxLineLength = 4096 // assumed <= bufio.defaultBufSize + maxValueLength = 4096 + maxHeaderLines = 1024 + chunkSize = 4 << 10 // 4 KB chunks + defaultMaxMemory = 32 << 20 // 32 MB ) +// ErrMissingFile is returned by FormFile when the provided file field name +// is either not present in the request or not a file field. +var ErrMissingFile = os.ErrorString("http: no such file") + // HTTP request parsing errors. type ProtocolError struct { os.ErrorString @@ -65,9 +69,12 @@ var reqExcludeHeader = map[string]bool{ // A Request represents a parsed HTTP request header. type Request struct { - Method string // GET, POST, PUT, etc. - RawURL string // The raw URL given in the request. - URL *URL // Parsed URL. + Method string // GET, POST, PUT, etc. + RawURL string // The raw URL given in the request. + URL *URL // Parsed URL. + + // The protocol version for incoming requests. + // Outgoing requests always use HTTP/1.1. Proto string // "HTTP/1.0" ProtoMajor int // 1 ProtoMinor int // 0 @@ -134,6 +141,10 @@ type Request struct { // The parsed form. Only available after ParseForm is called. Form map[string][]string + // The parsed multipart form, including file uploads. + // Only available after ParseMultipartForm is called. + MultipartForm *multipart.Form + // Trailer maps trailer keys to values. Like for Header, if the // response has multiple trailer lines with the same key, they will be // concatenated, delimited by commas. @@ -163,9 +174,30 @@ func (r *Request) ProtoAtLeast(major, minor int) bool { r.ProtoMajor == major && r.ProtoMinor >= minor } +// multipartByReader is a sentinel value. +// Its presence in Request.MultipartForm indicates that parsing of the request +// body has been handed off to a MultipartReader instead of ParseMultipartFrom. +var multipartByReader = &multipart.Form{ + Value: make(map[string][]string), + File: make(map[string][]*multipart.FileHeader), +} + // MultipartReader returns a MIME multipart reader if this is a // multipart/form-data POST request, else returns nil and an error. +// Use this function instead of ParseMultipartForm to +// process the request body as a stream. func (r *Request) MultipartReader() (multipart.Reader, os.Error) { + if r.MultipartForm == multipartByReader { + return nil, os.NewError("http: MultipartReader called twice") + } + if r.MultipartForm != nil { + return nil, os.NewError("http: multipart handled by ParseMultipartForm") + } + r.MultipartForm = multipartByReader + return r.multipartReader() +} + +func (r *Request) multipartReader() (multipart.Reader, os.Error) { v := r.Header.Get("Content-Type") if v == "" { return nil, ErrNotMultipart @@ -199,10 +231,14 @@ const defaultUserAgent = "Go http package" // UserAgent (defaults to defaultUserAgent) // Referer // Header +// Cookie +// ContentLength +// TransferEncoding // Body // -// If Body is present, Write forces "Transfer-Encoding: chunked" as a header -// and then closes Body when finished sending it. +// If Body is present but Content-Length is <= 0, Write adds +// "Transfer-Encoding: chunked" to the header. Body is closed after +// it is sent. func (req *Request) Write(w io.Writer) os.Error { return req.write(w, false) } @@ -420,6 +456,29 @@ func (cr *chunkedReader) Read(b []uint8) (n int, err os.Error) { return n, cr.err } +// NewRequest returns a new Request given a method, URL, and optional body. +func NewRequest(method, url string, body io.Reader) (*Request, os.Error) { + u, err := ParseURL(url) + if err != nil { + return nil, err + } + rc, ok := body.(io.ReadCloser) + if !ok && body != nil { + rc = ioutil.NopCloser(body) + } + req := &Request{ + Method: method, + URL: u, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: make(Header), + Body: rc, + Host: u.Host, + } + return req, nil +} + // ReadRequest reads and parses a request from b. func ReadRequest(b *bufio.Reader) (req *Request, err os.Error) { @@ -549,7 +608,9 @@ func parseQuery(m map[string][]string, query string) (err os.Error) { return err } -// ParseForm parses the request body as a form for POST requests, or the raw query for GET requests. +// ParseForm parses the raw query. +// For POST requests, it also parses the request body as a form. +// ParseMultipartForm calls ParseForm automatically. // It is idempotent. func (r *Request) ParseForm() (err os.Error) { if r.Form != nil { @@ -567,18 +628,23 @@ func (r *Request) ParseForm() (err os.Error) { ct := r.Header.Get("Content-Type") switch strings.Split(ct, ";", 2)[0] { case "text/plain", "application/x-www-form-urlencoded", "": - b, e := ioutil.ReadAll(r.Body) + const maxFormSize = int64(10 << 20) // 10 MB is a lot of text. + b, e := ioutil.ReadAll(io.LimitReader(r.Body, maxFormSize+1)) if e != nil { if err == nil { err = e } break } + if int64(len(b)) > maxFormSize { + return os.NewError("http: POST too large") + } e = parseQuery(r.Form, string(b)) if err == nil { err = e } - // TODO(dsymonds): Handle multipart/form-data + case "multipart/form-data": + // handled by ParseMultipartForm default: return &badStringError{"unknown Content-Type", ct} } @@ -586,11 +652,50 @@ func (r *Request) ParseForm() (err os.Error) { return err } +// ParseMultipartForm parses a request body as multipart/form-data. +// The whole request body is parsed and up to a total of maxMemory bytes of +// its file parts are stored in memory, with the remainder stored on +// disk in temporary files. +// ParseMultipartForm calls ParseForm if necessary. +// After one call to ParseMultipartForm, subsequent calls have no effect. +func (r *Request) ParseMultipartForm(maxMemory int64) os.Error { + if r.Form == nil { + err := r.ParseForm() + if err != nil { + return err + } + } + if r.MultipartForm != nil { + return nil + } + if r.MultipartForm == multipartByReader { + return os.NewError("http: multipart handled by MultipartReader") + } + + mr, err := r.multipartReader() + if err == ErrNotMultipart { + return nil + } else if err != nil { + return err + } + + f, err := mr.ReadForm(maxMemory) + if err != nil { + return err + } + for k, v := range f.Value { + r.Form[k] = append(r.Form[k], v...) + } + r.MultipartForm = f + + return nil +} + // FormValue returns the first value for the named component of the query. -// FormValue calls ParseForm if necessary. +// FormValue calls ParseMultipartForm and ParseForm if necessary. func (r *Request) FormValue(key string) string { if r.Form == nil { - r.ParseForm() + r.ParseMultipartForm(defaultMaxMemory) } if vs := r.Form[key]; len(vs) > 0 { return vs[0] @@ -598,6 +703,25 @@ func (r *Request) FormValue(key string) string { return "" } +// FormFile returns the first file for the provided form key. +// FormFile calls ParseMultipartForm and ParseForm if necessary. +func (r *Request) FormFile(key string) (multipart.File, *multipart.FileHeader, os.Error) { + if r.MultipartForm == multipartByReader { + return nil, nil, os.NewError("http: multipart handled by MultipartReader") + } + if r.MultipartForm == nil { + err := r.ParseMultipartForm(defaultMaxMemory) + if err != nil { + return nil, nil, err + } + } + if fhs := r.MultipartForm.File[key]; len(fhs) > 0 { + f, err := fhs[0].Open() + return f, fhs[0], err + } + return nil, nil, ErrMissingFile +} + func (r *Request) expectsContinue() bool { return strings.ToLower(r.Header.Get("Expect")) == "100-continue" } diff --git a/src/pkg/http/request_test.go b/src/pkg/http/request_test.go index 19083adf6..f982471d8 100644 --- a/src/pkg/http/request_test.go +++ b/src/pkg/http/request_test.go @@ -10,6 +10,8 @@ import ( . "http" "http/httptest" "io" + "io/ioutil" + "mime/multipart" "os" "reflect" "regexp" @@ -82,7 +84,7 @@ func TestPostQuery(t *testing.T) { req.Header = Header{ "Content-Type": {"application/x-www-form-urlencoded; boo!"}, } - req.Body = nopCloser{strings.NewReader("z=post&both=y")} + req.Body = ioutil.NopCloser(strings.NewReader("z=post&both=y")) if q := req.FormValue("q"); q != "foo" { t.Errorf(`req.FormValue("q") = %q, want "foo"`, q) } @@ -115,7 +117,7 @@ func TestPostContentTypeParsing(t *testing.T) { req := &Request{ Method: "POST", Header: Header(test.contentType), - Body: nopCloser{bytes.NewBufferString("body")}, + Body: ioutil.NopCloser(bytes.NewBufferString("body")), } err := req.ParseForm() if !test.error && err != nil { @@ -131,7 +133,7 @@ func TestMultipartReader(t *testing.T) { req := &Request{ Method: "POST", Header: Header{"Content-Type": {`multipart/form-data; boundary="foo123"`}}, - Body: nopCloser{new(bytes.Buffer)}, + Body: ioutil.NopCloser(new(bytes.Buffer)), } multipart, err := req.MultipartReader() if multipart == nil { @@ -170,9 +172,115 @@ func TestRedirect(t *testing.T) { } } -// TODO: stop copy/pasting this around. move to io/ioutil? -type nopCloser struct { - io.Reader +func TestMultipartRequest(t *testing.T) { + // Test that we can read the values and files of a + // multipart request with FormValue and FormFile, + // and that ParseMultipartForm can be called multiple times. + req := newTestMultipartRequest(t) + if err := req.ParseMultipartForm(25); err != nil { + t.Fatal("ParseMultipartForm first call:", err) + } + defer req.MultipartForm.RemoveAll() + validateTestMultipartContents(t, req, false) + if err := req.ParseMultipartForm(25); err != nil { + t.Fatal("ParseMultipartForm second call:", err) + } + validateTestMultipartContents(t, req, false) +} + +func TestMultipartRequestAuto(t *testing.T) { + // Test that FormValue and FormFile automatically invoke + // ParseMultipartForm and return the right values. + req := newTestMultipartRequest(t) + defer func() { + if req.MultipartForm != nil { + req.MultipartForm.RemoveAll() + } + }() + validateTestMultipartContents(t, req, true) +} + +func newTestMultipartRequest(t *testing.T) *Request { + b := bytes.NewBufferString(strings.Replace(message, "\n", "\r\n", -1)) + req, err := NewRequest("POST", "/", b) + if err != nil { + t.Fatalf("NewRequest:", err) + } + ctype := fmt.Sprintf(`multipart/form-data; boundary="%s"`, boundary) + req.Header.Set("Content-type", ctype) + return req +} + +func validateTestMultipartContents(t *testing.T, req *Request, allMem bool) { + if g, e := req.FormValue("texta"), textaValue; g != e { + t.Errorf("texta value = %q, want %q", g, e) + } + if g, e := req.FormValue("texta"), textaValue; g != e { + t.Errorf("texta value = %q, want %q", g, e) + } + + assertMem := func(n string, fd multipart.File) { + if _, ok := fd.(*os.File); ok { + t.Error(n, " is *os.File, should not be") + } + } + fd := testMultipartFile(t, req, "filea", "filea.txt", fileaContents) + assertMem("filea", fd) + fd = testMultipartFile(t, req, "fileb", "fileb.txt", filebContents) + if allMem { + assertMem("fileb", fd) + } else { + if _, ok := fd.(*os.File); !ok { + t.Errorf("fileb has unexpected underlying type %T", fd) + } + } +} + +func testMultipartFile(t *testing.T, req *Request, key, expectFilename, expectContent string) multipart.File { + f, fh, err := req.FormFile(key) + if err != nil { + t.Fatalf("FormFile(%q):", key, err) + } + if fh.Filename != expectFilename { + t.Errorf("filename = %q, want %q", fh.Filename, expectFilename) + } + var b bytes.Buffer + _, err = io.Copy(&b, f) + if err != nil { + t.Fatal("copying contents:", err) + } + if g := b.String(); g != expectContent { + t.Errorf("contents = %q, want %q", g, expectContent) + } + return f } -func (nopCloser) Close() os.Error { return nil } +const ( + fileaContents = "This is a test file." + filebContents = "Another test file." + textaValue = "foo" + textbValue = "bar" + boundary = `MyBoundary` +) + +const message = ` +--MyBoundary +Content-Disposition: form-data; name="filea"; filename="filea.txt" +Content-Type: text/plain + +` + fileaContents + ` +--MyBoundary +Content-Disposition: form-data; name="fileb"; filename="fileb.txt" +Content-Type: text/plain + +` + filebContents + ` +--MyBoundary +Content-Disposition: form-data; name="texta" + +` + textaValue + ` +--MyBoundary +Content-Disposition: form-data; name="textb" + +` + textbValue + ` +--MyBoundary-- +` diff --git a/src/pkg/http/requestwrite_test.go b/src/pkg/http/requestwrite_test.go index 726baa266..bb000c701 100644 --- a/src/pkg/http/requestwrite_test.go +++ b/src/pkg/http/requestwrite_test.go @@ -6,7 +6,10 @@ package http import ( "bytes" + "io" "io/ioutil" + "os" + "strings" "testing" ) @@ -133,6 +136,41 @@ var reqWriteTests = []reqWriteTest{ "Transfer-Encoding: chunked\r\n\r\n" + "6\r\nabcdef\r\n0\r\n\r\n", }, + + // HTTP/1.1 POST with Content-Length, no chunking + { + Request{ + Method: "POST", + URL: &URL{ + Scheme: "http", + Host: "www.google.com", + Path: "/search", + }, + ProtoMajor: 1, + ProtoMinor: 1, + Header: Header{}, + Close: true, + ContentLength: 6, + }, + + []byte("abcdef"), + + "POST /search HTTP/1.1\r\n" + + "Host: www.google.com\r\n" + + "User-Agent: Go http package\r\n" + + "Connection: close\r\n" + + "Content-Length: 6\r\n" + + "\r\n" + + "abcdef", + + "POST http://www.google.com/search HTTP/1.1\r\n" + + "User-Agent: Go http package\r\n" + + "Connection: close\r\n" + + "Content-Length: 6\r\n" + + "\r\n" + + "abcdef", + }, + // default to HTTP/1.1 { Request{ @@ -189,3 +227,26 @@ func TestRequestWrite(t *testing.T) { } } } + +type closeChecker struct { + io.Reader + closed bool +} + +func (rc *closeChecker) Close() os.Error { + rc.closed = true + return nil +} + +// TestRequestWriteClosesBody tests that Request.Write does close its request.Body. +// It also indirectly tests NewRequest and that it doesn't wrap an existing Closer +// inside a NopCloser. +func TestRequestWriteClosesBody(t *testing.T) { + rc := &closeChecker{Reader: strings.NewReader("my body")} + req, _ := NewRequest("GET", "http://foo.com/", rc) + buf := new(bytes.Buffer) + req.Write(buf) + if !rc.closed { + t.Error("body not closed after write") + } +} diff --git a/src/pkg/http/response_test.go b/src/pkg/http/response_test.go index 314f05b36..9e77c20c4 100644 --- a/src/pkg/http/response_test.go +++ b/src/pkg/http/response_test.go @@ -7,8 +7,12 @@ package http import ( "bufio" "bytes" + "compress/gzip" + "crypto/rand" "fmt" + "os" "io" + "io/ioutil" "reflect" "testing" ) @@ -117,7 +121,9 @@ var respTests = []respTest{ "Transfer-Encoding: chunked\r\n" + "\r\n" + "0a\r\n" + - "Body here\n" + + "Body here\n\r\n" + + "09\r\n" + + "continued\r\n" + "0\r\n" + "\r\n", @@ -134,7 +140,7 @@ var respTests = []respTest{ TransferEncoding: []string{"chunked"}, }, - "Body here\n", + "Body here\ncontinued", }, // Chunked response with Content-Length. @@ -186,6 +192,29 @@ var respTests = []respTest{ "", }, + // explicit Content-Length of 0. + { + "HTTP/1.1 200 OK\r\n" + + "Content-Length: 0\r\n" + + "\r\n", + + Response{ + Status: "200 OK", + StatusCode: 200, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + RequestMethod: "GET", + Header: Header{ + "Content-Length": {"0"}, + }, + Close: false, + ContentLength: 0, + }, + + "", + }, + // Status line without a Reason-Phrase, but trailing space. // (permitted by RFC 2616) { @@ -250,9 +279,107 @@ func TestReadResponse(t *testing.T) { } } +var readResponseCloseInMiddleTests = []struct { + chunked, compressed bool +}{ + {false, false}, + {true, false}, + {true, true}, +} + +// TestReadResponseCloseInMiddle tests that closing a body after +// reading only part of its contents advances the read to the end of +// the request, right up until the next request. +func TestReadResponseCloseInMiddle(t *testing.T) { + for _, test := range readResponseCloseInMiddleTests { + fatalf := func(format string, args ...interface{}) { + args = append([]interface{}{test.chunked, test.compressed}, args...) + t.Fatalf("on test chunked=%v, compressed=%v: "+format, args...) + } + checkErr := func(err os.Error, msg string) { + if err == nil { + return + } + fatalf(msg+": %v", err) + } + var buf bytes.Buffer + buf.WriteString("HTTP/1.1 200 OK\r\n") + if test.chunked { + buf.WriteString("Transfer-Encoding: chunked\r\n") + } else { + buf.WriteString("Content-Length: 1000000\r\n") + } + var wr io.Writer = &buf + if test.chunked { + wr = &chunkedWriter{wr} + } + if test.compressed { + buf.WriteString("Content-Encoding: gzip\r\n") + var err os.Error + wr, err = gzip.NewWriter(wr) + checkErr(err, "gzip.NewWriter") + } + buf.WriteString("\r\n") + + chunk := bytes.Repeat([]byte{'x'}, 1000) + for i := 0; i < 1000; i++ { + if test.compressed { + // Otherwise this compresses too well. + _, err := io.ReadFull(rand.Reader, chunk) + checkErr(err, "rand.Reader ReadFull") + } + wr.Write(chunk) + } + if test.compressed { + err := wr.(*gzip.Compressor).Close() + checkErr(err, "compressor close") + } + if test.chunked { + buf.WriteString("0\r\n\r\n") + } + buf.WriteString("Next Request Here") + + bufr := bufio.NewReader(&buf) + resp, err := ReadResponse(bufr, "GET") + checkErr(err, "ReadResponse") + expectedLength := int64(-1) + if !test.chunked { + expectedLength = 1000000 + } + if resp.ContentLength != expectedLength { + fatalf("expected response length %d, got %d", expectedLength, resp.ContentLength) + } + if resp.Body == nil { + fatalf("nil body") + } + if test.compressed { + gzReader, err := gzip.NewReader(resp.Body) + checkErr(err, "gzip.NewReader") + resp.Body = &readFirstCloseBoth{gzReader, resp.Body} + } + + rbuf := make([]byte, 2500) + n, err := io.ReadFull(resp.Body, rbuf) + checkErr(err, "2500 byte ReadFull") + if n != 2500 { + fatalf("ReadFull only read %d bytes", n) + } + if test.compressed == false && !bytes.Equal(bytes.Repeat([]byte{'x'}, 2500), rbuf) { + fatalf("ReadFull didn't read 2500 'x'; got %q", string(rbuf)) + } + resp.Body.Close() + + rest, err := ioutil.ReadAll(bufr) + checkErr(err, "ReadAll on remainder") + if e, g := "Next Request Here", string(rest); e != g { + fatalf("for chunked=%v remainder = %q, expected %q", g, e) + } + } +} + func diff(t *testing.T, prefix string, have, want interface{}) { - hv := reflect.NewValue(have).Elem() - wv := reflect.NewValue(want).Elem() + hv := reflect.ValueOf(have).Elem() + wv := reflect.ValueOf(want).Elem() if hv.Type() != wv.Type() { t.Errorf("%s: type mismatch %v vs %v", prefix, hv.Type(), wv.Type()) } diff --git a/src/pkg/http/reverseproxy.go b/src/pkg/http/reverseproxy.go new file mode 100644 index 000000000..e4ce1e34c --- /dev/null +++ b/src/pkg/http/reverseproxy.go @@ -0,0 +1,100 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// HTTP reverse proxy handler + +package http + +import ( + "io" + "log" + "net" + "strings" +) + +// ReverseProxy is an HTTP Handler that takes an incoming request and +// sends it to another server, proxying the response back to the +// client. +type ReverseProxy struct { + // Director must be a function which modifies + // the request into a new request to be sent + // using Transport. Its response is then copied + // back to the original client unmodified. + Director func(*Request) + + // The Transport used to perform proxy requests. + // If nil, DefaultTransport is used. + Transport RoundTripper +} + +func singleJoiningSlash(a, b string) string { + aslash := strings.HasSuffix(a, "/") + bslash := strings.HasPrefix(b, "/") + switch { + case aslash && bslash: + return a + b[1:] + case !aslash && !bslash: + return a + "/" + b + } + return a + b +} + +// NewSingleHostReverseProxy returns a new ReverseProxy that rewrites +// URLs to the scheme, host, and base path provided in target. If the +// target's path is "/base" and the incoming request was for "/dir", +// the target request will be for /base/dir. +func NewSingleHostReverseProxy(target *URL) *ReverseProxy { + director := func(req *Request) { + req.URL.Scheme = target.Scheme + req.URL.Host = target.Host + req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path) + if q := req.URL.RawQuery; q != "" { + req.URL.RawPath = req.URL.Path + "?" + q + } else { + req.URL.RawPath = req.URL.Path + } + req.URL.RawQuery = target.RawQuery + } + return &ReverseProxy{Director: director} +} + +func (p *ReverseProxy) ServeHTTP(rw ResponseWriter, req *Request) { + transport := p.Transport + if transport == nil { + transport = DefaultTransport + } + + outreq := new(Request) + *outreq = *req // includes shallow copies of maps, but okay + + p.Director(outreq) + outreq.Proto = "HTTP/1.1" + outreq.ProtoMajor = 1 + outreq.ProtoMinor = 1 + outreq.Close = false + + if clientIp, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { + outreq.Header.Set("X-Forwarded-For", clientIp) + } + + res, err := transport.RoundTrip(outreq) + if err != nil { + log.Printf("http: proxy error: %v", err) + rw.WriteHeader(StatusInternalServerError) + return + } + + hdr := rw.Header() + for k, vv := range res.Header { + for _, v := range vv { + hdr.Add(k, v) + } + } + + rw.WriteHeader(res.StatusCode) + + if res.Body != nil { + io.Copy(rw, res.Body) + } +} diff --git a/src/pkg/http/reverseproxy_test.go b/src/pkg/http/reverseproxy_test.go new file mode 100644 index 000000000..8cf7705d7 --- /dev/null +++ b/src/pkg/http/reverseproxy_test.go @@ -0,0 +1,50 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Reverse proxy tests. + +package http_test + +import ( + . "http" + "http/httptest" + "io/ioutil" + "testing" +) + +func TestReverseProxy(t *testing.T) { + const backendResponse = "I am the backend" + const backendStatus = 404 + backend := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + if r.Header.Get("X-Forwarded-For") == "" { + t.Errorf("didn't get X-Forwarded-For header") + } + w.Header().Set("X-Foo", "bar") + w.WriteHeader(backendStatus) + w.Write([]byte(backendResponse)) + })) + defer backend.Close() + backendURL, err := ParseURL(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := NewSingleHostReverseProxy(backendURL) + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + res, _, err := Get(frontend.URL) + if err != nil { + t.Fatalf("Get: %v", err) + } + if g, e := res.StatusCode, backendStatus; g != e { + t.Errorf("got res.StatusCode %d; expected %d", g, e) + } + if g, e := res.Header.Get("X-Foo"), "bar"; g != e { + t.Errorf("got X-Foo %q; expected %q", g, e) + } + bodyBytes, _ := ioutil.ReadAll(res.Body) + if g, e := string(bodyBytes), backendResponse; g != e { + t.Errorf("got body %q; expected %q", g, e) + } +} diff --git a/src/pkg/http/serve_test.go b/src/pkg/http/serve_test.go index 0142dead9..c3c7b8d33 100644 --- a/src/pkg/http/serve_test.go +++ b/src/pkg/http/serve_test.go @@ -247,7 +247,7 @@ func TestServerTimeouts(t *testing.T) { server := &Server{Handler: handler, ReadTimeout: 0.25 * second, WriteTimeout: 0.25 * second} go server.Serve(l) - url := fmt.Sprintf("http://localhost:%d/", addr.Port) + url := fmt.Sprintf("http://%s/", addr) // Hit the HTTP server successfully. tr := &Transport{DisableKeepAlives: true} // they interfere with this test @@ -265,7 +265,7 @@ func TestServerTimeouts(t *testing.T) { // Slow client that should timeout. t1 := time.Nanoseconds() - conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", addr.Port)) + conn, err := net.Dial("tcp", addr.String()) if err != nil { t.Fatalf("Dial: %v", err) } @@ -588,7 +588,7 @@ func TestServerExpect(t *testing.T) { sendf := func(format string, args ...interface{}) { _, err := fmt.Fprintf(conn, format, args...) if err != nil { - t.Fatalf("Error writing %q: %v", format, err) + t.Fatalf("On test %#v, error writing %q: %v", test, format, err) } } go func() { @@ -616,3 +616,100 @@ func TestServerExpect(t *testing.T) { runTest(test) } } + +func TestServerConsumesRequestBody(t *testing.T) { + log := make(chan string, 100) + + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + log <- "got_request" + w.WriteHeader(StatusOK) + log <- "wrote_header" + })) + defer ts.Close() + + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer conn.Close() + + bufr := bufio.NewReader(conn) + gotres := make(chan bool) + go func() { + line, err := bufr.ReadString('\n') + if err != nil { + t.Fatal(err) + } + log <- line + gotres <- true + }() + + size := 1 << 20 + log <- "writing_request" + fmt.Fprintf(conn, "POST / HTTP/1.0\r\nContent-Length: %d\r\n\r\n", size) + time.Sleep(25e6) // give server chance to misbehave & speak out of turn + log <- "slept_after_req_headers" + conn.Write([]byte(strings.Repeat("a", size))) + + <-gotres + expected := []string{ + "writing_request", "got_request", + "slept_after_req_headers", "wrote_header", + "HTTP/1.0 200 OK\r\n"} + for step, e := range expected { + if g := <-log; e != g { + t.Errorf("on step %d expected %q, got %q", step, e, g) + } + } +} + +func TestTimeoutHandler(t *testing.T) { + sendHi := make(chan bool, 1) + writeErrors := make(chan os.Error, 1) + sayHi := HandlerFunc(func(w ResponseWriter, r *Request) { + <-sendHi + _, werr := w.Write([]byte("hi")) + writeErrors <- werr + }) + timeout := make(chan int64, 1) // write to this to force timeouts + ts := httptest.NewServer(NewTestTimeoutHandler(sayHi, timeout)) + defer ts.Close() + + // Succeed without timing out: + sendHi <- true + res, _, err := Get(ts.URL) + if err != nil { + t.Error(err) + } + if g, e := res.StatusCode, StatusOK; g != e { + t.Errorf("got res.StatusCode %d; expected %d", g, e) + } + body, _ := ioutil.ReadAll(res.Body) + if g, e := string(body), "hi"; g != e { + t.Errorf("got body %q; expected %q", g, e) + } + if g := <-writeErrors; g != nil { + t.Errorf("got unexpected Write error on first request: %v", g) + } + + // Times out: + timeout <- 1 + res, _, err = Get(ts.URL) + if err != nil { + t.Error(err) + } + if g, e := res.StatusCode, StatusServiceUnavailable; g != e { + t.Errorf("got res.StatusCode %d; expected %d", g, e) + } + body, _ = ioutil.ReadAll(res.Body) + if !strings.Contains(string(body), "<title>Timeout</title>") { + t.Errorf("expected timeout body; got %q", string(body)) + } + + // Now make the previously-timed out handler speak again, + // which verifies the panic is handled: + sendHi <- true + if g, e := <-writeErrors, ErrHandlerTimeout; g != e { + t.Errorf("expected Write error of %v; got %v", e, g) + } +} diff --git a/src/pkg/http/server.go b/src/pkg/http/server.go index 3291de101..96d2cb638 100644 --- a/src/pkg/http/server.go +++ b/src/pkg/http/server.go @@ -22,6 +22,7 @@ import ( "path" "strconv" "strings" + "sync" "time" ) @@ -141,9 +142,13 @@ func newConn(rwc net.Conn, handler Handler) (c *conn, err os.Error) { type expectContinueReader struct { resp *response readCloser io.ReadCloser + closed bool } func (ecr *expectContinueReader) Read(p []byte) (n int, err os.Error) { + if ecr.closed { + return 0, os.NewError("http: Read after Close on request Body") + } if !ecr.resp.wroteContinue && !ecr.resp.conn.hijacked { ecr.resp.wroteContinue = true io.WriteString(ecr.resp.conn.buf, "HTTP/1.1 100 Continue\r\n\r\n") @@ -153,6 +158,7 @@ func (ecr *expectContinueReader) Read(p []byte) (n int, err os.Error) { } func (ecr *expectContinueReader) Close() os.Error { + ecr.closed = true return ecr.readCloser.Close() } @@ -196,6 +202,16 @@ func (w *response) WriteHeader(code int) { log.Print("http: multiple response.WriteHeader calls") return } + + // Per RFC 2616, we should consume the request body before + // replying, if the handler hasn't already done so. + if w.req.ContentLength != 0 { + ecr, isExpecter := w.req.Body.(*expectContinueReader) + if !isExpecter || ecr.resp.wroteContinue { + w.req.Body.Close() + } + } + w.wroteHeader = true w.status = code if code == StatusNotModified { @@ -407,6 +423,9 @@ func (w *response) finishRequest() { } w.conn.buf.Flush() w.req.Body.Close() + if w.req.MultipartForm != nil { + w.req.MultipartForm.RemoveAll() + } if w.contentLength != -1 && w.contentLength != w.written { // Did not write enough. Avoid getting out of sync. @@ -883,3 +902,89 @@ func ListenAndServeTLS(addr string, certFile string, keyFile string, handler Han tlsListener := tls.NewListener(conn, config) return Serve(tlsListener, handler) } + +// TimeoutHandler returns a Handler that runs h with the given time limit. +// +// The new Handler calls h.ServeHTTP to handle each request, but if a +// call runs for more than ns nanoseconds, the handler responds with +// a 503 Service Unavailable error and the given message in its body. +// (If msg is empty, a suitable default message will be sent.) +// After such a timeout, writes by h to its ResponseWriter will return +// ErrHandlerTimeout. +func TimeoutHandler(h Handler, ns int64, msg string) Handler { + f := func() <-chan int64 { + return time.After(ns) + } + return &timeoutHandler{h, f, msg} +} + +// ErrHandlerTimeout is returned on ResponseWriter Write calls +// in handlers which have timed out. +var ErrHandlerTimeout = os.NewError("http: Handler timeout") + +type timeoutHandler struct { + handler Handler + timeout func() <-chan int64 // returns channel producing a timeout + body string +} + +func (h *timeoutHandler) errorBody() string { + if h.body != "" { + return h.body + } + return "<html><head><title>Timeout</title></head><body><h1>Timeout</h1></body></html>" +} + +func (h *timeoutHandler) ServeHTTP(w ResponseWriter, r *Request) { + done := make(chan bool) + tw := &timeoutWriter{w: w} + go func() { + h.handler.ServeHTTP(tw, r) + done <- true + }() + select { + case <-done: + return + case <-h.timeout(): + tw.mu.Lock() + defer tw.mu.Unlock() + if !tw.wroteHeader { + tw.w.WriteHeader(StatusServiceUnavailable) + tw.w.Write([]byte(h.errorBody())) + } + tw.timedOut = true + } +} + +type timeoutWriter struct { + w ResponseWriter + + mu sync.Mutex + timedOut bool + wroteHeader bool +} + +func (tw *timeoutWriter) Header() Header { + return tw.w.Header() +} + +func (tw *timeoutWriter) Write(p []byte) (int, os.Error) { + tw.mu.Lock() + timedOut := tw.timedOut + tw.mu.Unlock() + if timedOut { + return 0, ErrHandlerTimeout + } + return tw.w.Write(p) +} + +func (tw *timeoutWriter) WriteHeader(code int) { + tw.mu.Lock() + if tw.timedOut || tw.wroteHeader { + tw.mu.Unlock() + return + } + tw.wroteHeader = true + tw.mu.Unlock() + tw.w.WriteHeader(code) +} diff --git a/src/pkg/http/transfer.go b/src/pkg/http/transfer.go index 41614f144..98c32bab6 100644 --- a/src/pkg/http/transfer.go +++ b/src/pkg/http/transfer.go @@ -7,6 +7,7 @@ package http import ( "bufio" "io" + "io/ioutil" "os" "strconv" "strings" @@ -447,17 +448,10 @@ func (b *body) Close() os.Error { return nil } - trashBuf := make([]byte, 1024) // local for thread safety - for { - _, err := b.Read(trashBuf) - if err == nil { - continue - } - if err == os.EOF { - break - } + if _, err := io.Copy(ioutil.Discard, b); err != nil { return err } + if b.hdr == nil { // not reading trailer return nil } diff --git a/src/pkg/http/transport.go b/src/pkg/http/transport.go index 7fa37af3b..73a2c2191 100644 --- a/src/pkg/http/transport.go +++ b/src/pkg/http/transport.go @@ -6,6 +6,7 @@ package http import ( "bufio" + "bytes" "compress/gzip" "crypto/tls" "encoding/base64" @@ -217,6 +218,9 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, os.Error) { conn, err := net.Dial("tcp", cm.addr()) if err != nil { + if cm.proxyURL != nil { + err = fmt.Errorf("http: error connecting to proxy %s: %v", cm.proxyURL, err) + } return nil, err } @@ -288,10 +292,28 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, os.Error) { // useProxy returns true if requests to addr should use a proxy, // according to the NO_PROXY or no_proxy environment variable. +// addr is always a canonicalAddr with a host and port. func (t *Transport) useProxy(addr string) bool { if len(addr) == 0 { return true } + host, _, err := net.SplitHostPort(addr) + if err != nil { + return false + } + if host == "localhost" { + return false + } + if ip := net.ParseIP(host); ip != nil { + if ip4 := ip.To4(); ip4 != nil && ip4[0] == 127 { + // 127.0.0.0/8 loopback isn't proxied. + return false + } + if bytes.Equal(ip, net.IPv6loopback) { + return false + } + } + no_proxy := t.getenvEitherCase("NO_PROXY") if no_proxy == "*" { return false @@ -510,12 +532,13 @@ func (pc *persistConn) roundTrip(req *Request) (resp *Response, err os.Error) { re.res.Header.Del("Content-Encoding") re.res.Header.Del("Content-Length") re.res.ContentLength = -1 - var err os.Error - re.res.Body, err = gzip.NewReader(re.res.Body) + esb := re.res.Body.(*bodyEOFSignal) + gzReader, err := gzip.NewReader(esb.body) if err != nil { pc.close() return nil, err } + esb.body = &readFirstCloseBoth{gzReader, esb.body} } return re.res, re.err @@ -554,7 +577,7 @@ func responseIsKeepAlive(res *Response) bool { func readResponseWithEOFSignal(r *bufio.Reader, requestMethod string) (resp *Response, err os.Error) { resp, err = ReadResponse(r, requestMethod) if err == nil && resp.ContentLength != 0 { - resp.Body = &bodyEOFSignal{resp.Body, nil} + resp.Body = &bodyEOFSignal{body: resp.Body} } return } @@ -563,12 +586,16 @@ func readResponseWithEOFSignal(r *bufio.Reader, requestMethod string) (resp *Res // once, right before the final Read() or Close() call returns, but after // EOF has been seen. type bodyEOFSignal struct { - body io.ReadCloser - fn func() + body io.ReadCloser + fn func() + isClosed bool } func (es *bodyEOFSignal) Read(p []byte) (n int, err os.Error) { n, err = es.body.Read(p) + if es.isClosed && n > 0 { + panic("http: unexpected bodyEOFSignal Read after Close; see issue 1725") + } if err == os.EOF && es.fn != nil { es.fn() es.fn = nil @@ -577,6 +604,7 @@ func (es *bodyEOFSignal) Read(p []byte) (n int, err os.Error) { } func (es *bodyEOFSignal) Close() (err os.Error) { + es.isClosed = true err = es.body.Close() if err == nil && es.fn != nil { es.fn() @@ -584,3 +612,19 @@ func (es *bodyEOFSignal) Close() (err os.Error) { } return } + +type readFirstCloseBoth struct { + io.ReadCloser + io.Closer +} + +func (r *readFirstCloseBoth) Close() os.Error { + if err := r.ReadCloser.Close(); err != nil { + r.Closer.Close() + return err + } + if err := r.Closer.Close(); err != nil { + return err + } + return nil +} diff --git a/src/pkg/http/transport_test.go b/src/pkg/http/transport_test.go index f83deedfc..a32ac4c4f 100644 --- a/src/pkg/http/transport_test.go +++ b/src/pkg/http/transport_test.go @@ -9,11 +9,14 @@ package http_test import ( "bytes" "compress/gzip" + "crypto/rand" "fmt" . "http" "http/httptest" + "io" "io/ioutil" "os" + "strconv" "testing" "time" ) @@ -179,35 +182,47 @@ func TestTransportIdleCacheKeys(t *testing.T) { } func TestTransportMaxPerHostIdleConns(t *testing.T) { - ch := make(chan string) + resch := make(chan string) + gotReq := make(chan bool) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { - w.Write([]byte(<-ch)) + gotReq <- true + msg := <-resch + _, err := w.Write([]byte(msg)) + if err != nil { + t.Fatalf("Write: %v", err) + } })) defer ts.Close() maxIdleConns := 2 tr := &Transport{DisableKeepAlives: false, MaxIdleConnsPerHost: maxIdleConns} c := &Client{Transport: tr} - // Start 3 outstanding requests (will hang until we write to - // ch) + // Start 3 outstanding requests and wait for the server to get them. + // Their responses will hang until we we write to resch, though. donech := make(chan bool) doReq := func() { resp, _, err := c.Get(ts.URL) if err != nil { t.Error(err) } - ioutil.ReadAll(resp.Body) + _, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatalf("ReadAll: %v", err) + } donech <- true } go doReq() + <-gotReq go doReq() + <-gotReq go doReq() + <-gotReq if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g { t.Fatalf("Before writes, expected %d idle conn cache keys; got %d", e, g) } - ch <- "res1" + resch <- "res1" <-donech keys := tr.IdleConnKeysForTesting() if e, g := 1, len(keys); e != g { @@ -221,13 +236,13 @@ func TestTransportMaxPerHostIdleConns(t *testing.T) { t.Errorf("after first response, expected %d idle conns; got %d", e, g) } - ch <- "res2" + resch <- "res2" <-donech if e, g := 2, tr.IdleConnCountForTesting(cacheKey); e != g { t.Errorf("after second response, expected %d idle conns; got %d", e, g) } - ch <- "res3" + resch <- "res3" <-donech if e, g := maxIdleConns, tr.IdleConnCountForTesting(cacheKey); e != g { t.Errorf("after third response, still expected %d idle conns; got %d", e, g) @@ -355,32 +370,80 @@ func TestTransportNilURL(t *testing.T) { func TestTransportGzip(t *testing.T) { const testString = "The test string aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { - if g, e := r.Header.Get("Accept-Encoding"), "gzip"; g != e { + const nRandBytes = 1024 * 1024 + ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + if g, e := req.Header.Get("Accept-Encoding"), "gzip"; g != e { t.Errorf("Accept-Encoding = %q, want %q", g, e) } - w.Header().Set("Content-Encoding", "gzip") + rw.Header().Set("Content-Encoding", "gzip") + + var w io.Writer = rw + var buf bytes.Buffer + if req.FormValue("chunked") == "0" { + w = &buf + defer io.Copy(rw, &buf) + defer func() { + rw.Header().Set("Content-Length", strconv.Itoa(buf.Len())) + }() + } gz, _ := gzip.NewWriter(w) - defer gz.Close() gz.Write([]byte(testString)) - + if req.FormValue("body") == "large" { + io.Copyn(gz, rand.Reader, nRandBytes) + } + gz.Close() })) defer ts.Close() - c := &Client{Transport: &Transport{}} - res, _, err := c.Get(ts.URL) - if err != nil { - t.Fatal(err) - } - body, err := ioutil.ReadAll(res.Body) - if err != nil { - t.Fatal(err) - } - if g, e := string(body), testString; g != e { - t.Fatalf("body = %q; want %q", g, e) - } - if g, e := res.Header.Get("Content-Encoding"), ""; g != e { - t.Fatalf("Content-Encoding = %q; want %q", g, e) + for _, chunked := range []string{"1", "0"} { + c := &Client{Transport: &Transport{}} + + // First fetch something large, but only read some of it. + res, _, err := c.Get(ts.URL + "?body=large&chunked=" + chunked) + if err != nil { + t.Fatalf("large get: %v", err) + } + buf := make([]byte, len(testString)) + n, err := io.ReadFull(res.Body, buf) + if err != nil { + t.Fatalf("partial read of large response: size=%d, %v", n, err) + } + if e, g := testString, string(buf); e != g { + t.Errorf("partial read got %q, expected %q", g, e) + } + res.Body.Close() + // Read on the body, even though it's closed + n, err = res.Body.Read(buf) + if n != 0 || err == nil { + t.Errorf("expected error post-closed large Read; got = %d, %v", n, err) + } + + // Then something small. + res, _, err = c.Get(ts.URL + "?chunked=" + chunked) + if err != nil { + t.Fatal(err) + } + body, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if g, e := string(body), testString; g != e { + t.Fatalf("body = %q; want %q", g, e) + } + if g, e := res.Header.Get("Content-Encoding"), ""; g != e { + t.Fatalf("Content-Encoding = %q; want %q", g, e) + } + + // Read on the body after it's been fully read: + n, err = res.Body.Read(buf) + if n != 0 || err == nil { + t.Errorf("expected Read error after exhausted reads; got %d, %v", n, err) + } + res.Body.Close() + n, err = res.Body.Read(buf) + if n != 0 || err == nil { + t.Errorf("expected Read error after Close; got %d, %v", n, err) + } } } diff --git a/src/pkg/image/image.go b/src/pkg/image/image.go index c0e96e1f7..5f398a304 100644 --- a/src/pkg/image/image.go +++ b/src/pkg/image/image.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The image package implements a basic 2-D image library. +// Package image implements a basic 2-D image library. package image // A Config consists of an image's color model and dimensions. diff --git a/src/pkg/image/jpeg/Makefile b/src/pkg/image/jpeg/Makefile index 5c5f97e71..d9d830f2f 100644 --- a/src/pkg/image/jpeg/Makefile +++ b/src/pkg/image/jpeg/Makefile @@ -6,8 +6,10 @@ include ../../../Make.inc TARG=image/jpeg GOFILES=\ + fdct.go\ huffman.go\ idct.go\ reader.go\ + writer.go\ include ../../../Make.pkg diff --git a/src/pkg/image/jpeg/fdct.go b/src/pkg/image/jpeg/fdct.go new file mode 100644 index 000000000..3f8be4e32 --- /dev/null +++ b/src/pkg/image/jpeg/fdct.go @@ -0,0 +1,190 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package jpeg + +// This file implements a Forward Discrete Cosine Transformation. + +/* +It is based on the code in jfdctint.c from the Independent JPEG Group, +found at http://www.ijg.org/files/jpegsrc.v8c.tar.gz. + +The "LEGAL ISSUES" section of the README in that archive says: + +In plain English: + +1. We don't promise that this software works. (But if you find any bugs, + please let us know!) +2. You can use this software for whatever you want. You don't have to pay us. +3. You may not pretend that you wrote this software. If you use it in a + program, you must acknowledge somewhere in your documentation that + you've used the IJG code. + +In legalese: + +The authors make NO WARRANTY or representation, either express or implied, +with respect to this software, its quality, accuracy, merchantability, or +fitness for a particular purpose. This software is provided "AS IS", and you, +its user, assume the entire risk as to its quality and accuracy. + +This software is copyright (C) 1991-2011, Thomas G. Lane, Guido Vollbeding. +All Rights Reserved except as specified below. + +Permission is hereby granted to use, copy, modify, and distribute this +software (or portions thereof) for any purpose, without fee, subject to these +conditions: +(1) If any part of the source code for this software is distributed, then this +README file must be included, with this copyright and no-warranty notice +unaltered; and any additions, deletions, or changes to the original files +must be clearly indicated in accompanying documentation. +(2) If only executable code is distributed, then the accompanying +documentation must state that "this software is based in part on the work of +the Independent JPEG Group". +(3) Permission for use of this software is granted only if the user accepts +full responsibility for any undesirable consequences; the authors accept +NO LIABILITY for damages of any kind. + +These conditions apply to any software derived from or based on the IJG code, +not just to the unmodified library. If you use our work, you ought to +acknowledge us. + +Permission is NOT granted for the use of any IJG author's name or company name +in advertising or publicity relating to this software or products derived from +it. This software may be referred to only as "the Independent JPEG Group's +software". + +We specifically permit and encourage the use of this software as the basis of +commercial products, provided that all warranty or liability claims are +assumed by the product vendor. +*/ + +// Trigonometric constants in 13-bit fixed point format. +const ( + fix_0_298631336 = 2446 + fix_0_390180644 = 3196 + fix_0_541196100 = 4433 + fix_0_765366865 = 6270 + fix_0_899976223 = 7373 + fix_1_175875602 = 9633 + fix_1_501321110 = 12299 + fix_1_847759065 = 15137 + fix_1_961570560 = 16069 + fix_2_053119869 = 16819 + fix_2_562915447 = 20995 + fix_3_072711026 = 25172 +) + +const ( + constBits = 13 + pass1Bits = 2 + centerJSample = 128 +) + +// fdct performs a forward DCT on an 8x8 block of coefficients, including a +// level shift. +func fdct(b *block) { + // Pass 1: process rows. + for y := 0; y < 8; y++ { + x0 := b[y*8+0] + x1 := b[y*8+1] + x2 := b[y*8+2] + x3 := b[y*8+3] + x4 := b[y*8+4] + x5 := b[y*8+5] + x6 := b[y*8+6] + x7 := b[y*8+7] + + tmp0 := x0 + x7 + tmp1 := x1 + x6 + tmp2 := x2 + x5 + tmp3 := x3 + x4 + + tmp10 := tmp0 + tmp3 + tmp12 := tmp0 - tmp3 + tmp11 := tmp1 + tmp2 + tmp13 := tmp1 - tmp2 + + tmp0 = x0 - x7 + tmp1 = x1 - x6 + tmp2 = x2 - x5 + tmp3 = x3 - x4 + + b[y*8+0] = (tmp10 + tmp11 - 8*centerJSample) << pass1Bits + b[y*8+4] = (tmp10 - tmp11) << pass1Bits + z1 := (tmp12 + tmp13) * fix_0_541196100 + z1 += 1 << (constBits - pass1Bits - 1) + b[y*8+2] = (z1 + tmp12*fix_0_765366865) >> (constBits - pass1Bits) + b[y*8+6] = (z1 - tmp13*fix_1_847759065) >> (constBits - pass1Bits) + + tmp10 = tmp0 + tmp3 + tmp11 = tmp1 + tmp2 + tmp12 = tmp0 + tmp2 + tmp13 = tmp1 + tmp3 + z1 = (tmp12 + tmp13) * fix_1_175875602 + z1 += 1 << (constBits - pass1Bits - 1) + tmp0 = tmp0 * fix_1_501321110 + tmp1 = tmp1 * fix_3_072711026 + tmp2 = tmp2 * fix_2_053119869 + tmp3 = tmp3 * fix_0_298631336 + tmp10 = tmp10 * -fix_0_899976223 + tmp11 = tmp11 * -fix_2_562915447 + tmp12 = tmp12 * -fix_0_390180644 + tmp13 = tmp13 * -fix_1_961570560 + + tmp12 += z1 + tmp13 += z1 + b[y*8+1] = (tmp0 + tmp10 + tmp12) >> (constBits - pass1Bits) + b[y*8+3] = (tmp1 + tmp11 + tmp13) >> (constBits - pass1Bits) + b[y*8+5] = (tmp2 + tmp11 + tmp12) >> (constBits - pass1Bits) + b[y*8+7] = (tmp3 + tmp10 + tmp13) >> (constBits - pass1Bits) + } + // Pass 2: process columns. + // We remove pass1Bits scaling, but leave results scaled up by an overall factor of 8. + for x := 0; x < 8; x++ { + tmp0 := b[0*8+x] + b[7*8+x] + tmp1 := b[1*8+x] + b[6*8+x] + tmp2 := b[2*8+x] + b[5*8+x] + tmp3 := b[3*8+x] + b[4*8+x] + + tmp10 := tmp0 + tmp3 + 1<<(pass1Bits-1) + tmp12 := tmp0 - tmp3 + tmp11 := tmp1 + tmp2 + tmp13 := tmp1 - tmp2 + + tmp0 = b[0*8+x] - b[7*8+x] + tmp1 = b[1*8+x] - b[6*8+x] + tmp2 = b[2*8+x] - b[5*8+x] + tmp3 = b[3*8+x] - b[4*8+x] + + b[0*8+x] = (tmp10 + tmp11) >> pass1Bits + b[4*8+x] = (tmp10 - tmp11) >> pass1Bits + + z1 := (tmp12 + tmp13) * fix_0_541196100 + z1 += 1 << (constBits + pass1Bits - 1) + b[2*8+x] = (z1 + tmp12*fix_0_765366865) >> (constBits + pass1Bits) + b[6*8+x] = (z1 - tmp13*fix_1_847759065) >> (constBits + pass1Bits) + + tmp10 = tmp0 + tmp3 + tmp11 = tmp1 + tmp2 + tmp12 = tmp0 + tmp2 + tmp13 = tmp1 + tmp3 + z1 = (tmp12 + tmp13) * fix_1_175875602 + z1 += 1 << (constBits + pass1Bits - 1) + tmp0 = tmp0 * fix_1_501321110 + tmp1 = tmp1 * fix_3_072711026 + tmp2 = tmp2 * fix_2_053119869 + tmp3 = tmp3 * fix_0_298631336 + tmp10 = tmp10 * -fix_0_899976223 + tmp11 = tmp11 * -fix_2_562915447 + tmp12 = tmp12 * -fix_0_390180644 + tmp13 = tmp13 * -fix_1_961570560 + + tmp12 += z1 + tmp13 += z1 + b[1*8+x] = (tmp0 + tmp10 + tmp12) >> (constBits + pass1Bits) + b[3*8+x] = (tmp1 + tmp11 + tmp13) >> (constBits + pass1Bits) + b[5*8+x] = (tmp2 + tmp11 + tmp12) >> (constBits + pass1Bits) + b[7*8+x] = (tmp3 + tmp10 + tmp13) >> (constBits + pass1Bits) + } +} diff --git a/src/pkg/image/jpeg/idct.go b/src/pkg/image/jpeg/idct.go index 518993110..e5a2f40f5 100644 --- a/src/pkg/image/jpeg/idct.go +++ b/src/pkg/image/jpeg/idct.go @@ -63,7 +63,7 @@ const ( // // For more on the actual algorithm, see Z. Wang, "Fast algorithms for the discrete W transform and // for the discrete Fourier transform", IEEE Trans. on ASSP, Vol. ASSP- 32, pp. 803-816, Aug. 1984. -func idct(b *[blockSize]int) { +func idct(b *block) { // Horizontal 1-D IDCT. for y := 0; y < 8; y++ { // If all the AC components are zero, then the IDCT is trivial. diff --git a/src/pkg/image/jpeg/reader.go b/src/pkg/image/jpeg/reader.go index fb9cb11bb..21a6fff96 100644 --- a/src/pkg/image/jpeg/reader.go +++ b/src/pkg/image/jpeg/reader.go @@ -2,18 +2,22 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The jpeg package implements a decoder for JPEG images, as defined in ITU-T T.81. +// Package jpeg implements a JPEG image decoder and encoder. +// +// JPEG is defined in ITU-T T.81: http://www.w3.org/Graphics/JPEG/itu-t81.pdf. package jpeg -// See http://www.w3.org/Graphics/JPEG/itu-t81.pdf - import ( "bufio" "image" + "image/ycbcr" "io" "os" ) +// TODO(nigeltao): fix up the doc comment style so that sentences start with +// the name of the type or function that they annotate. + // A FormatError reports that the input is not a valid JPEG. type FormatError string @@ -26,12 +30,14 @@ func (e UnsupportedError) String() string { return "unsupported JPEG feature: " // Component specification, specified in section B.2.2. type component struct { + h int // Horizontal sampling factor. + v int // Vertical sampling factor. c uint8 // Component identifier. - h uint8 // Horizontal sampling factor. - v uint8 // Vertical sampling factor. tq uint8 // Quantization table destination selector. } +type block [blockSize]int + const ( blockSize = 64 // A DCT block is 8x8. @@ -84,13 +90,13 @@ type Reader interface { type decoder struct { r Reader width, height int - image *image.RGBA + img *ycbcr.YCbCr ri int // Restart Interval. comps [nComponent]component huff [maxTc + 1][maxTh + 1]huffman - quant [maxTq + 1][blockSize]int + quant [maxTq + 1]block b bits - blocks [nComponent][maxH * maxV][blockSize]int + blocks [nComponent][maxH * maxV]block tmp [1024]byte } @@ -130,9 +136,9 @@ func (d *decoder) processSOF(n int) os.Error { } for i := 0; i < nComponent; i++ { hv := d.tmp[7+3*i] + d.comps[i].h = int(hv >> 4) + d.comps[i].v = int(hv & 0x0f) d.comps[i].c = d.tmp[6+3*i] - d.comps[i].h = hv >> 4 - d.comps[i].v = hv & 0x0f d.comps[i].tq = d.tmp[8+3*i] // We only support YCbCr images, and 4:4:4, 4:2:2 or 4:2:0 chroma downsampling ratios. This implies that // the (h, v) values for the Y component are either (1, 1), (2, 1) or (2, 2), and the @@ -176,71 +182,47 @@ func (d *decoder) processDQT(n int) os.Error { return nil } -// Set the Pixel (px, py)'s RGB value, based on its YCbCr value. -func (d *decoder) calcPixel(px, py, lumaBlock, lumaIndex, chromaIndex int) { - y, cb, cr := d.blocks[0][lumaBlock][lumaIndex], d.blocks[1][0][chromaIndex], d.blocks[2][0][chromaIndex] - // The JFIF specification (http://www.w3.org/Graphics/JPEG/jfif3.pdf, page 3) gives the formula - // for translating YCbCr to RGB as: - // R = Y + 1.402 (Cr-128) - // G = Y - 0.34414 (Cb-128) - 0.71414 (Cr-128) - // B = Y + 1.772 (Cb-128) - yPlusHalf := 100000*y + 50000 - cb -= 128 - cr -= 128 - r := (yPlusHalf + 140200*cr) / 100000 - g := (yPlusHalf - 34414*cb - 71414*cr) / 100000 - b := (yPlusHalf + 177200*cb) / 100000 - if r < 0 { - r = 0 - } else if r > 255 { - r = 255 +// Clip x to the range [0, 255] inclusive. +func clip(x int) uint8 { + if x < 0 { + return 0 } - if g < 0 { - g = 0 - } else if g > 255 { - g = 255 + if x > 255 { + return 255 } - if b < 0 { - b = 0 - } else if b > 255 { - b = 255 - } - d.image.Pix[py*d.image.Stride+px] = image.RGBAColor{uint8(r), uint8(g), uint8(b), 0xff} + return uint8(x) } -// Convert the MCU from YCbCr to RGB. -func (d *decoder) convertMCU(mx, my, h0, v0 int) { - lumaBlock := 0 +// Store the MCU to the image. +func (d *decoder) storeMCU(mx, my int) { + h0, v0 := d.comps[0].h, d.comps[0].v + // Store the luma blocks. for v := 0; v < v0; v++ { for h := 0; h < h0; h++ { - chromaBase := 8*4*v + 4*h - py := 8 * (v0*my + v) - for y := 0; y < 8 && py < d.height; y++ { - px := 8 * (h0*mx + h) - lumaIndex := 8 * y - chromaIndex := chromaBase + 8*(y/v0) - for x := 0; x < 8 && px < d.width; x++ { - d.calcPixel(px, py, lumaBlock, lumaIndex, chromaIndex) - if h0 == 1 { - chromaIndex += 1 - } else { - chromaIndex += x % 2 - } - lumaIndex++ - px++ + p := 8 * ((v0*my+v)*d.img.YStride + (h0*mx + h)) + for y := 0; y < 8; y++ { + for x := 0; x < 8; x++ { + d.img.Y[p] = clip(d.blocks[0][h0*v+h][8*y+x]) + p++ } - py++ + p += d.img.YStride - 8 } - lumaBlock++ } } + // Store the chroma blocks. + p := 8 * (my*d.img.CStride + mx) + for y := 0; y < 8; y++ { + for x := 0; x < 8; x++ { + d.img.Cb[p] = clip(d.blocks[1][0][8*y+x]) + d.img.Cr[p] = clip(d.blocks[2][0][8*y+x]) + p++ + } + p += d.img.CStride - 8 + } } // Specified in section B.2.3. func (d *decoder) processSOS(n int) os.Error { - if d.image == nil { - d.image = image.NewRGBA(d.width, d.height) - } if n != 4+2*nComponent { return UnsupportedError("SOS has wrong length") } @@ -255,7 +237,6 @@ func (d *decoder) processSOS(n int) os.Error { td uint8 // DC table selector. ta uint8 // AC table selector. } - h0, v0 := int(d.comps[0].h), int(d.comps[0].v) // The h and v values from the Y components. for i := 0; i < nComponent; i++ { cs := d.tmp[1+2*i] // Component selector. if cs != d.comps[i].c { @@ -265,17 +246,42 @@ func (d *decoder) processSOS(n int) os.Error { scanComps[i].ta = d.tmp[2+2*i] & 0x0f } // mxx and myy are the number of MCUs (Minimum Coded Units) in the image. - mxx := (d.width + 8*int(h0) - 1) / (8 * int(h0)) - myy := (d.height + 8*int(v0) - 1) / (8 * int(v0)) + h0, v0 := d.comps[0].h, d.comps[0].v // The h and v values from the Y components. + mxx := (d.width + 8*h0 - 1) / (8 * h0) + myy := (d.height + 8*v0 - 1) / (8 * v0) + if d.img == nil { + var subsampleRatio ycbcr.SubsampleRatio + n := h0 * v0 + switch n { + case 1: + subsampleRatio = ycbcr.SubsampleRatio444 + case 2: + subsampleRatio = ycbcr.SubsampleRatio422 + case 4: + subsampleRatio = ycbcr.SubsampleRatio420 + default: + panic("unreachable") + } + b := make([]byte, mxx*myy*(1*8*8*n+2*8*8)) + d.img = &ycbcr.YCbCr{ + Y: b[mxx*myy*(0*8*8*n+0*8*8) : mxx*myy*(1*8*8*n+0*8*8)], + Cb: b[mxx*myy*(1*8*8*n+0*8*8) : mxx*myy*(1*8*8*n+1*8*8)], + Cr: b[mxx*myy*(1*8*8*n+1*8*8) : mxx*myy*(1*8*8*n+2*8*8)], + SubsampleRatio: subsampleRatio, + YStride: mxx * 8 * h0, + CStride: mxx * 8, + Rect: image.Rect(0, 0, d.width, d.height), + } + } mcu, expectedRST := 0, uint8(rst0Marker) - var allZeroes [blockSize]int + var allZeroes block var dc [nComponent]int for my := 0; my < myy; my++ { for mx := 0; mx < mxx; mx++ { for i := 0; i < nComponent; i++ { qt := &d.quant[d.comps[i].tq] - for j := 0; j < int(d.comps[i].h*d.comps[i].v); j++ { + for j := 0; j < d.comps[i].h*d.comps[i].v; j++ { d.blocks[i][j] = allZeroes // Decode the DC coefficient, as specified in section F.2.2.1. @@ -299,20 +305,20 @@ func (d *decoder) processSOS(n int) os.Error { if err != nil { return err } - v0 := value >> 4 - v1 := value & 0x0f - if v1 != 0 { - k += int(v0) + val0 := value >> 4 + val1 := value & 0x0f + if val1 != 0 { + k += int(val0) if k > blockSize { return FormatError("bad DCT index") } - ac, err := d.receiveExtend(v1) + ac, err := d.receiveExtend(val1) if err != nil { return err } d.blocks[i][j][unzig[k]] = ac * qt[k] } else { - if v0 != 0x0f { + if val0 != 0x0f { break } k += 0x0f @@ -322,7 +328,7 @@ func (d *decoder) processSOS(n int) os.Error { idct(&d.blocks[i][j]) } // for j } // for i - d.convertMCU(mx, my, int(d.comps[0].h), int(d.comps[0].v)) + d.storeMCU(mx, my) mcu++ if d.ri > 0 && mcu%d.ri == 0 && mcu < mxx*myy { // A more sophisticated decoder could use RST[0-7] markers to resynchronize from corrupt input, @@ -431,7 +437,7 @@ func (d *decoder) decode(r io.Reader, configOnly bool) (image.Image, os.Error) { return nil, err } } - return d.image, nil + return d.img, nil } // Decode reads a JPEG image from r and returns it as an image.Image. diff --git a/src/pkg/image/jpeg/writer.go b/src/pkg/image/jpeg/writer.go new file mode 100644 index 000000000..505cce04f --- /dev/null +++ b/src/pkg/image/jpeg/writer.go @@ -0,0 +1,523 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package jpeg + +import ( + "bufio" + "image" + "image/ycbcr" + "io" + "os" +) + +// min returns the minimum of two integers. +func min(x, y int) int { + if x < y { + return x + } + return y +} + +// div returns a/b rounded to the nearest integer, instead of rounded to zero. +func div(a int, b int) int { + if a >= 0 { + return (a + (b >> 1)) / b + } + return -((-a + (b >> 1)) / b) +} + +// bitCount counts the number of bits needed to hold an integer. +var bitCount = [256]byte{ + 0, 1, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, + 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, + 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, + 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, +} + +type quantIndex int + +const ( + quantIndexLuminance quantIndex = iota + quantIndexChrominance + nQuantIndex +) + +// unscaledQuant are the unscaled quantization tables. Each encoder copies and +// scales the tables according to its quality parameter. +var unscaledQuant = [nQuantIndex][blockSize]byte{ + // Luminance. + { + 16, 11, 10, 16, 24, 40, 51, 61, + 12, 12, 14, 19, 26, 58, 60, 55, + 14, 13, 16, 24, 40, 57, 69, 56, + 14, 17, 22, 29, 51, 87, 80, 62, + 18, 22, 37, 56, 68, 109, 103, 77, + 24, 35, 55, 64, 81, 104, 113, 92, + 49, 64, 78, 87, 103, 121, 120, 101, + 72, 92, 95, 98, 112, 100, 103, 99, + }, + // Chrominance. + { + 17, 18, 24, 47, 99, 99, 99, 99, + 18, 21, 26, 66, 99, 99, 99, 99, + 24, 26, 56, 99, 99, 99, 99, 99, + 47, 66, 99, 99, 99, 99, 99, 99, + 99, 99, 99, 99, 99, 99, 99, 99, + 99, 99, 99, 99, 99, 99, 99, 99, + 99, 99, 99, 99, 99, 99, 99, 99, + 99, 99, 99, 99, 99, 99, 99, 99, + }, +} + +type huffIndex int + +const ( + huffIndexLuminanceDC huffIndex = iota + huffIndexLuminanceAC + huffIndexChrominanceDC + huffIndexChrominanceAC + nHuffIndex +) + +// huffmanSpec specifies a Huffman encoding. +type huffmanSpec struct { + // count[i] is the number of codes of length i bits. + count [16]byte + // value[i] is the decoded value of the i'th codeword. + value []byte +} + +// theHuffmanSpec is the Huffman encoding specifications. +// This encoder uses the same Huffman encoding for all images. +var theHuffmanSpec = [nHuffIndex]huffmanSpec{ + // Luminance DC. + { + [16]byte{0, 1, 5, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0}, + []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, + }, + // Luminance AC. + { + [16]byte{0, 2, 1, 3, 3, 2, 4, 3, 5, 5, 4, 4, 0, 0, 1, 125}, + []byte{ + 0x01, 0x02, 0x03, 0x00, 0x04, 0x11, 0x05, 0x12, + 0x21, 0x31, 0x41, 0x06, 0x13, 0x51, 0x61, 0x07, + 0x22, 0x71, 0x14, 0x32, 0x81, 0x91, 0xa1, 0x08, + 0x23, 0x42, 0xb1, 0xc1, 0x15, 0x52, 0xd1, 0xf0, + 0x24, 0x33, 0x62, 0x72, 0x82, 0x09, 0x0a, 0x16, + 0x17, 0x18, 0x19, 0x1a, 0x25, 0x26, 0x27, 0x28, + 0x29, 0x2a, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, + 0x3a, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48, 0x49, + 0x4a, 0x53, 0x54, 0x55, 0x56, 0x57, 0x58, 0x59, + 0x5a, 0x63, 0x64, 0x65, 0x66, 0x67, 0x68, 0x69, + 0x6a, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79, + 0x7a, 0x83, 0x84, 0x85, 0x86, 0x87, 0x88, 0x89, + 0x8a, 0x92, 0x93, 0x94, 0x95, 0x96, 0x97, 0x98, + 0x99, 0x9a, 0xa2, 0xa3, 0xa4, 0xa5, 0xa6, 0xa7, + 0xa8, 0xa9, 0xaa, 0xb2, 0xb3, 0xb4, 0xb5, 0xb6, + 0xb7, 0xb8, 0xb9, 0xba, 0xc2, 0xc3, 0xc4, 0xc5, + 0xc6, 0xc7, 0xc8, 0xc9, 0xca, 0xd2, 0xd3, 0xd4, + 0xd5, 0xd6, 0xd7, 0xd8, 0xd9, 0xda, 0xe1, 0xe2, + 0xe3, 0xe4, 0xe5, 0xe6, 0xe7, 0xe8, 0xe9, 0xea, + 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8, + 0xf9, 0xfa, + }, + }, + // Chrominance DC. + { + [16]byte{0, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0}, + []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, + }, + // Chrominance AC. + { + [16]byte{0, 2, 1, 2, 4, 4, 3, 4, 7, 5, 4, 4, 0, 1, 2, 119}, + []byte{ + 0x00, 0x01, 0x02, 0x03, 0x11, 0x04, 0x05, 0x21, + 0x31, 0x06, 0x12, 0x41, 0x51, 0x07, 0x61, 0x71, + 0x13, 0x22, 0x32, 0x81, 0x08, 0x14, 0x42, 0x91, + 0xa1, 0xb1, 0xc1, 0x09, 0x23, 0x33, 0x52, 0xf0, + 0x15, 0x62, 0x72, 0xd1, 0x0a, 0x16, 0x24, 0x34, + 0xe1, 0x25, 0xf1, 0x17, 0x18, 0x19, 0x1a, 0x26, + 0x27, 0x28, 0x29, 0x2a, 0x35, 0x36, 0x37, 0x38, + 0x39, 0x3a, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48, + 0x49, 0x4a, 0x53, 0x54, 0x55, 0x56, 0x57, 0x58, + 0x59, 0x5a, 0x63, 0x64, 0x65, 0x66, 0x67, 0x68, + 0x69, 0x6a, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, + 0x79, 0x7a, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87, + 0x88, 0x89, 0x8a, 0x92, 0x93, 0x94, 0x95, 0x96, + 0x97, 0x98, 0x99, 0x9a, 0xa2, 0xa3, 0xa4, 0xa5, + 0xa6, 0xa7, 0xa8, 0xa9, 0xaa, 0xb2, 0xb3, 0xb4, + 0xb5, 0xb6, 0xb7, 0xb8, 0xb9, 0xba, 0xc2, 0xc3, + 0xc4, 0xc5, 0xc6, 0xc7, 0xc8, 0xc9, 0xca, 0xd2, + 0xd3, 0xd4, 0xd5, 0xd6, 0xd7, 0xd8, 0xd9, 0xda, + 0xe2, 0xe3, 0xe4, 0xe5, 0xe6, 0xe7, 0xe8, 0xe9, + 0xea, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8, + 0xf9, 0xfa, + }, + }, +} + +// huffmanLUT is a compiled look-up table representation of a huffmanSpec. +// Each value maps to a uint32 of which the 8 most significant bits hold the +// codeword size in bits and the 24 least significant bits hold the codeword. +// The maximum codeword size is 16 bits. +type huffmanLUT []uint32 + +func (h *huffmanLUT) init(s huffmanSpec) { + maxValue := 0 + for _, v := range s.value { + if int(v) > maxValue { + maxValue = int(v) + } + } + *h = make([]uint32, maxValue+1) + code, k := uint32(0), 0 + for i := 0; i < len(s.count); i++ { + nBits := uint32(i+1) << 24 + for j := uint8(0); j < s.count[i]; j++ { + (*h)[s.value[k]] = nBits | code + code++ + k++ + } + code <<= 1 + } +} + +// theHuffmanLUT are compiled representations of theHuffmanSpec. +var theHuffmanLUT [4]huffmanLUT + +func init() { + for i, s := range theHuffmanSpec { + theHuffmanLUT[i].init(s) + } +} + +// writer is a buffered writer. +type writer interface { + Flush() os.Error + Write([]byte) (int, os.Error) + WriteByte(byte) os.Error +} + +// encoder encodes an image to the JPEG format. +type encoder struct { + // w is the writer to write to. err is the first error encountered during + // writing. All attempted writes after the first error become no-ops. + w writer + err os.Error + // buf is a scratch buffer. + buf [16]byte + // bits and nBits are accumulated bits to write to w. + bits uint32 + nBits uint8 + // quant is the scaled quantization tables. + quant [nQuantIndex][blockSize]byte +} + +func (e *encoder) flush() { + if e.err != nil { + return + } + e.err = e.w.Flush() +} + +func (e *encoder) write(p []byte) { + if e.err != nil { + return + } + _, e.err = e.w.Write(p) +} + +func (e *encoder) writeByte(b byte) { + if e.err != nil { + return + } + e.err = e.w.WriteByte(b) +} + +// emit emits the least significant nBits bits of bits to the bitstream. +// The precondition is bits < 1<<nBits && nBits <= 16. +func (e *encoder) emit(bits uint32, nBits uint8) { + nBits += e.nBits + bits <<= 32 - nBits + bits |= e.bits + for nBits >= 8 { + b := uint8(bits >> 24) + e.writeByte(b) + if b == 0xff { + e.writeByte(0x00) + } + bits <<= 8 + nBits -= 8 + } + e.bits, e.nBits = bits, nBits +} + +// emitHuff emits the given value with the given Huffman encoder. +func (e *encoder) emitHuff(h huffIndex, value int) { + x := theHuffmanLUT[h][value] + e.emit(x&(1<<24-1), uint8(x>>24)) +} + +// emitHuffRLE emits a run of runLength copies of value encoded with the given +// Huffman encoder. +func (e *encoder) emitHuffRLE(h huffIndex, runLength, value int) { + a, b := value, value + if a < 0 { + a, b = -value, value-1 + } + var nBits uint8 + if a < 0x100 { + nBits = bitCount[a] + } else { + nBits = 8 + bitCount[a>>8] + } + e.emitHuff(h, runLength<<4|int(nBits)) + if nBits > 0 { + e.emit(uint32(b)&(1<<nBits-1), nBits) + } +} + +// writeMarkerHeader writes the header for a marker with the given length. +func (e *encoder) writeMarkerHeader(marker uint8, markerlen int) { + e.buf[0] = 0xff + e.buf[1] = marker + e.buf[2] = uint8(markerlen >> 8) + e.buf[3] = uint8(markerlen & 0xff) + e.write(e.buf[:4]) +} + +// writeDQT writes the Define Quantization Table marker. +func (e *encoder) writeDQT() { + markerlen := 2 + for _, q := range e.quant { + markerlen += 1 + len(q) + } + e.writeMarkerHeader(dqtMarker, markerlen) + for i, q := range e.quant { + e.writeByte(uint8(i)) + e.write(q[:]) + } +} + +// writeSOF0 writes the Start Of Frame (Baseline) marker. +func (e *encoder) writeSOF0(size image.Point) { + markerlen := 8 + 3*nComponent + e.writeMarkerHeader(sof0Marker, markerlen) + e.buf[0] = 8 // 8-bit color. + e.buf[1] = uint8(size.Y >> 8) + e.buf[2] = uint8(size.Y & 0xff) + e.buf[3] = uint8(size.X >> 8) + e.buf[4] = uint8(size.X & 0xff) + e.buf[5] = nComponent + for i := 0; i < nComponent; i++ { + e.buf[3*i+6] = uint8(i + 1) + // We use 4:2:0 chroma subsampling. + e.buf[3*i+7] = "\x22\x11\x11"[i] + e.buf[3*i+8] = "\x00\x01\x01"[i] + } + e.write(e.buf[:3*(nComponent-1)+9]) +} + +// writeDHT writes the Define Huffman Table marker. +func (e *encoder) writeDHT() { + markerlen := 2 + for _, s := range theHuffmanSpec { + markerlen += 1 + 16 + len(s.value) + } + e.writeMarkerHeader(dhtMarker, markerlen) + for i, s := range theHuffmanSpec { + e.writeByte("\x00\x10\x01\x11"[i]) + e.write(s.count[:]) + e.write(s.value) + } +} + +// writeBlock writes a block of pixel data using the given quantization table, +// returning the post-quantized DC value of the DCT-transformed block. +func (e *encoder) writeBlock(b *block, q quantIndex, prevDC int) int { + fdct(b) + // Emit the DC delta. + dc := div(b[0], (8 * int(e.quant[q][0]))) + e.emitHuffRLE(huffIndex(2*q+0), 0, dc-prevDC) + // Emit the AC components. + h, runLength := huffIndex(2*q+1), 0 + for k := 1; k < blockSize; k++ { + ac := div(b[unzig[k]], (8 * int(e.quant[q][k]))) + if ac == 0 { + runLength++ + } else { + for runLength > 15 { + e.emitHuff(h, 0xf0) + runLength -= 16 + } + e.emitHuffRLE(h, runLength, ac) + runLength = 0 + } + } + if runLength > 0 { + e.emitHuff(h, 0x00) + } + return dc +} + +// toYCbCr converts the 8x8 region of m whose top-left corner is p to its +// YCbCr values. +func toYCbCr(m image.Image, p image.Point, yBlock, cbBlock, crBlock *block) { + b := m.Bounds() + xmax := b.Max.X - 1 + ymax := b.Max.Y - 1 + for j := 0; j < 8; j++ { + for i := 0; i < 8; i++ { + r, g, b, _ := m.At(min(p.X+i, xmax), min(p.Y+j, ymax)).RGBA() + yy, cb, cr := ycbcr.RGBToYCbCr(uint8(r>>8), uint8(g>>8), uint8(b>>8)) + yBlock[8*j+i] = int(yy) + cbBlock[8*j+i] = int(cb) + crBlock[8*j+i] = int(cr) + } + } +} + +// scale scales the 16x16 region represented by the 4 src blocks to the 8x8 +// dst block. +func scale(dst *block, src *[4]block) { + for i := 0; i < 4; i++ { + dstOff := (i&2)<<4 | (i&1)<<2 + for y := 0; y < 4; y++ { + for x := 0; x < 4; x++ { + j := 16*y + 2*x + sum := src[i][j] + src[i][j+1] + src[i][j+8] + src[i][j+9] + dst[8*y+x+dstOff] = (sum + 2) >> 2 + } + } + } +} + +// sosHeader is the SOS marker "\xff\xda" followed by 12 bytes: +// - the marker length "\x00\x0c", +// - the number of components "\x03", +// - component 1 uses DC table 0 and AC table 0 "\x01\x00", +// - component 2 uses DC table 1 and AC table 1 "\x02\x11", +// - component 3 uses DC table 1 and AC table 1 "\x03\x11", +// - padding "\x00\x00\x00". +var sosHeader = []byte{ + 0xff, 0xda, 0x00, 0x0c, 0x03, 0x01, 0x00, 0x02, + 0x11, 0x03, 0x11, 0x00, 0x00, 0x00, +} + +// writeSOS writes the StartOfScan marker. +func (e *encoder) writeSOS(m image.Image) { + e.write(sosHeader) + var ( + // Scratch buffers to hold the YCbCr values. + yBlock block + cbBlock [4]block + crBlock [4]block + cBlock block + // DC components are delta-encoded. + prevDCY, prevDCCb, prevDCCr int + ) + bounds := m.Bounds() + for y := bounds.Min.Y; y < bounds.Max.Y; y += 16 { + for x := bounds.Min.X; x < bounds.Max.X; x += 16 { + for i := 0; i < 4; i++ { + xOff := (i & 1) * 8 + yOff := (i & 2) * 4 + p := image.Point{x + xOff, y + yOff} + toYCbCr(m, p, &yBlock, &cbBlock[i], &crBlock[i]) + prevDCY = e.writeBlock(&yBlock, 0, prevDCY) + } + scale(&cBlock, &cbBlock) + prevDCCb = e.writeBlock(&cBlock, 1, prevDCCb) + scale(&cBlock, &crBlock) + prevDCCr = e.writeBlock(&cBlock, 1, prevDCCr) + } + } + // Pad the last byte with 1's. + e.emit(0x7f, 7) +} + +// DefaultQuality is the default quality encoding parameter. +const DefaultQuality = 75 + +// Options are the encoding parameters. +// Quality ranges from 1 to 100 inclusive, higher is better. +type Options struct { + Quality int +} + +// Encode writes the Image m to w in JPEG 4:2:0 baseline format with the given +// options. Default parameters are used if a nil *Options is passed. +func Encode(w io.Writer, m image.Image, o *Options) os.Error { + b := m.Bounds() + if b.Dx() >= 1<<16 || b.Dy() >= 1<<16 { + return os.NewError("jpeg: image is too large to encode") + } + var e encoder + if ww, ok := w.(writer); ok { + e.w = ww + } else { + e.w = bufio.NewWriter(w) + } + // Clip quality to [1, 100]. + quality := DefaultQuality + if o != nil { + quality = o.Quality + if quality < 1 { + quality = 1 + } else if quality > 100 { + quality = 100 + } + } + // Convert from a quality rating to a scaling factor. + var scale int + if quality < 50 { + scale = 5000 / quality + } else { + scale = 200 - quality*2 + } + // Initialize the quantization tables. + for i := range e.quant { + for j := range e.quant[i] { + x := int(unscaledQuant[i][j]) + x = (x*scale + 50) / 100 + if x < 1 { + x = 1 + } else if x > 255 { + x = 255 + } + e.quant[i][j] = uint8(x) + } + } + // Write the Start Of Image marker. + e.buf[0] = 0xff + e.buf[1] = 0xd8 + e.write(e.buf[:2]) + // Write the quantization tables. + e.writeDQT() + // Write the image dimensions. + e.writeSOF0(b.Size()) + // Write the Huffman tables. + e.writeDHT() + // Write the image data. + e.writeSOS(m) + // Write the End Of Image marker. + e.buf[0] = 0xff + e.buf[1] = 0xd9 + e.write(e.buf[:2]) + e.flush() + return e.err +} diff --git a/src/pkg/image/jpeg/writer_test.go b/src/pkg/image/jpeg/writer_test.go new file mode 100644 index 000000000..00922dd5c --- /dev/null +++ b/src/pkg/image/jpeg/writer_test.go @@ -0,0 +1,87 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package jpeg + +import ( + "bytes" + "image" + "image/png" + "os" + "testing" +) + +var testCase = []struct { + filename string + quality int + tolerance int64 +}{ + {"../testdata/video-001.png", 1, 24 << 8}, + {"../testdata/video-001.png", 20, 12 << 8}, + {"../testdata/video-001.png", 60, 8 << 8}, + {"../testdata/video-001.png", 80, 6 << 8}, + {"../testdata/video-001.png", 90, 4 << 8}, + {"../testdata/video-001.png", 100, 2 << 8}, +} + +func delta(u0, u1 uint32) int64 { + d := int64(u0) - int64(u1) + if d < 0 { + return -d + } + return d +} + +func readPng(filename string) (image.Image, os.Error) { + f, err := os.Open(filename) + if err != nil { + return nil, err + } + defer f.Close() + return png.Decode(f) +} + +func TestWriter(t *testing.T) { + for _, tc := range testCase { + // Read the image. + m0, err := readPng(tc.filename) + if err != nil { + t.Error(tc.filename, err) + continue + } + // Encode that image as JPEG. + buf := bytes.NewBuffer(nil) + err = Encode(buf, m0, &Options{Quality: tc.quality}) + if err != nil { + t.Error(tc.filename, err) + continue + } + // Decode that JPEG. + m1, err := Decode(buf) + if err != nil { + t.Error(tc.filename, err) + continue + } + // Compute the average delta in RGB space. + b := m0.Bounds() + var sum, n int64 + for y := b.Min.Y; y < b.Max.Y; y++ { + for x := b.Min.X; x < b.Max.X; x++ { + c0 := m0.At(x, y) + c1 := m1.At(x, y) + r0, g0, b0, _ := c0.RGBA() + r1, g1, b1, _ := c1.RGBA() + sum += delta(r0, r1) + sum += delta(g0, g1) + sum += delta(b0, b1) + n += 3 + } + } + // Compare the average delta to the tolerance level. + if sum/n > tc.tolerance { + t.Errorf("%s, quality=%d: average delta is too high", tc.filename, tc.quality) + continue + } + } +} diff --git a/src/pkg/image/png/reader.go b/src/pkg/image/png/reader.go index eee4eac2e..b30a951c1 100644 --- a/src/pkg/image/png/reader.go +++ b/src/pkg/image/png/reader.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The png package implements a PNG image decoder and encoder. +// Package png implements a PNG image decoder and encoder. // // The PNG specification is at http://www.libpng.org/pub/png/spec/1.2/PNG-Contents.html package png diff --git a/src/pkg/image/ycbcr/ycbcr.go b/src/pkg/image/ycbcr/ycbcr.go index b2e033b82..cda45996d 100644 --- a/src/pkg/image/ycbcr/ycbcr.go +++ b/src/pkg/image/ycbcr/ycbcr.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The ycbcr package provides images from the Y'CbCr color model. +// Package ycbcr provides images from the Y'CbCr color model. // // JPEG, VP8, the MPEG family and other codecs use this color model. Such // codecs often use the terms YUV and Y'CbCr interchangeably, but strictly diff --git a/src/pkg/index/suffixarray/suffixarray.go b/src/pkg/index/suffixarray/suffixarray.go index d8c6fc91b..079b7d8ed 100644 --- a/src/pkg/index/suffixarray/suffixarray.go +++ b/src/pkg/index/suffixarray/suffixarray.go @@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The suffixarray package implements substring search in logarithmic time -// using an in-memory suffix array. +// Package suffixarray implements substring search in logarithmic time using +// an in-memory suffix array. // // Example use: // diff --git a/src/pkg/io/io.go b/src/pkg/io/io.go index d3707eb1d..0bc73d67d 100644 --- a/src/pkg/io/io.go +++ b/src/pkg/io/io.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package provides basic interfaces to I/O primitives. +// Package io provides basic interfaces to I/O primitives. // Its primary job is to wrap existing implementations of such primitives, // such as those in package os, into shared public interfaces that // abstract the functionality, plus some other related primitives. diff --git a/src/pkg/io/ioutil/ioutil.go b/src/pkg/io/ioutil/ioutil.go index 57d797e85..5f1eecaab 100644 --- a/src/pkg/io/ioutil/ioutil.go +++ b/src/pkg/io/ioutil/ioutil.go @@ -2,8 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Utility functions. - +// Package ioutil implements some I/O utility functions. package ioutil import ( @@ -102,3 +101,13 @@ func (nopCloser) Close() os.Error { return nil } func NopCloser(r io.Reader) io.ReadCloser { return nopCloser{r} } + +type devNull int + +func (devNull) Write(p []byte) (int, os.Error) { + return len(p), nil +} + +// Discard is an io.Writer on which all Write calls succeed +// without doing anything. +var Discard io.Writer = devNull(0) diff --git a/src/pkg/json/decode.go b/src/pkg/json/decode.go index a5fd33912..e78b60ccb 100644 --- a/src/pkg/json/decode.go +++ b/src/pkg/json/decode.go @@ -122,11 +122,10 @@ func (d *decodeState) unmarshal(v interface{}) (err os.Error) { } }() - rv := reflect.NewValue(v) + rv := reflect.ValueOf(v) pv := rv - if pv.Kind() != reflect.Ptr || - pv.IsNil() { - return &InvalidUnmarshalError{reflect.Typeof(v)} + if pv.Kind() != reflect.Ptr || pv.IsNil() { + return &InvalidUnmarshalError{reflect.TypeOf(v)} } d.scan.reset() @@ -267,17 +266,17 @@ func (d *decodeState) indirect(v reflect.Value, wantptr bool) (Unmarshaler, refl v = iv.Elem() continue } + pv := v if pv.Kind() != reflect.Ptr { break } - if pv.Elem().Kind() != reflect.Ptr && - wantptr && !isUnmarshaler { + if pv.Elem().Kind() != reflect.Ptr && wantptr && pv.CanSet() && !isUnmarshaler { return nil, pv } if pv.IsNil() { - pv.Set(reflect.Zero(pv.Type().Elem()).Addr()) + pv.Set(reflect.New(pv.Type().Elem())) } if isUnmarshaler { // Using v.Interface().(Unmarshaler) @@ -314,7 +313,7 @@ func (d *decodeState) array(v reflect.Value) { iv := v ok := iv.Kind() == reflect.Interface if ok { - iv.Set(reflect.NewValue(d.arrayInterface())) + iv.Set(reflect.ValueOf(d.arrayInterface())) return } @@ -410,7 +409,7 @@ func (d *decodeState) object(v reflect.Value) { // Decoding into nil interface? Switch to non-reflect code. iv := v if iv.Kind() == reflect.Interface { - iv.Set(reflect.NewValue(d.objectInterface())) + iv.Set(reflect.ValueOf(d.objectInterface())) return } @@ -423,7 +422,7 @@ func (d *decodeState) object(v reflect.Value) { case reflect.Map: // map must have string type t := v.Type() - if t.Key() != reflect.Typeof("") { + if t.Key() != reflect.TypeOf("") { d.saveError(&UnmarshalTypeError{"object", v.Type()}) break } @@ -443,6 +442,8 @@ func (d *decodeState) object(v reflect.Value) { return } + var mapElem reflect.Value + for { // Read opening " of string key or closing }. op := d.scanWhile(scanSkipSpace) @@ -466,7 +467,13 @@ func (d *decodeState) object(v reflect.Value) { // Figure out field corresponding to key. var subv reflect.Value if mv.IsValid() { - subv = reflect.Zero(mv.Type().Elem()) + elemType := mv.Type().Elem() + if !mapElem.IsValid() { + mapElem = reflect.New(elemType).Elem() + } else { + mapElem.Set(reflect.Zero(elemType)) + } + subv = mapElem } else { var f reflect.StructField var ok bool @@ -514,7 +521,7 @@ func (d *decodeState) object(v reflect.Value) { // Write value back to map; // if using struct, subv points into struct already. if mv.IsValid() { - mv.SetMapIndex(reflect.NewValue(key), subv) + mv.SetMapIndex(reflect.ValueOf(key), subv) } // Next token must be , or }. @@ -570,7 +577,7 @@ func (d *decodeState) literal(v reflect.Value) { case reflect.Bool: v.SetBool(value) case reflect.Interface: - v.Set(reflect.NewValue(value)) + v.Set(reflect.ValueOf(value)) } case '"': // string @@ -592,11 +599,11 @@ func (d *decodeState) literal(v reflect.Value) { d.saveError(err) break } - v.Set(reflect.NewValue(b[0:n])) + v.Set(reflect.ValueOf(b[0:n])) case reflect.String: v.SetString(string(s)) case reflect.Interface: - v.Set(reflect.NewValue(string(s))) + v.Set(reflect.ValueOf(string(s))) } default: // number @@ -613,7 +620,7 @@ func (d *decodeState) literal(v reflect.Value) { d.saveError(&UnmarshalTypeError{"number " + s, v.Type()}) break } - v.Set(reflect.NewValue(n)) + v.Set(reflect.ValueOf(n)) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: n, err := strconv.Atoi64(s) @@ -767,7 +774,7 @@ func (d *decodeState) literalInterface() interface{} { } n, err := strconv.Atof64(string(item)) if err != nil { - d.saveError(&UnmarshalTypeError{"number " + string(item), reflect.Typeof(0.0)}) + d.saveError(&UnmarshalTypeError{"number " + string(item), reflect.TypeOf(0.0)}) } return n } diff --git a/src/pkg/json/decode_test.go b/src/pkg/json/decode_test.go index 49135c4bf..bf8bf10bf 100644 --- a/src/pkg/json/decode_test.go +++ b/src/pkg/json/decode_test.go @@ -21,7 +21,7 @@ type tx struct { x int } -var txType = reflect.Typeof((*tx)(nil)).Elem() +var txType = reflect.TypeOf((*tx)(nil)).Elem() // A type that can unmarshal itself. @@ -64,14 +64,14 @@ var unmarshalTests = []unmarshalTest{ {`"g-clef: \uD834\uDD1E"`, new(string), "g-clef: \U0001D11E", nil}, {`"invalid: \uD834x\uDD1E"`, new(string), "invalid: \uFFFDx\uFFFD", nil}, {"null", new(interface{}), nil, nil}, - {`{"X": [1,2,3], "Y": 4}`, new(T), T{Y: 4}, &UnmarshalTypeError{"array", reflect.Typeof("")}}, + {`{"X": [1,2,3], "Y": 4}`, new(T), T{Y: 4}, &UnmarshalTypeError{"array", reflect.TypeOf("")}}, {`{"x": 1}`, new(tx), tx{}, &UnmarshalFieldError{"x", txType, txType.Field(0)}}, // skip invalid tags {`{"X":"a", "y":"b", "Z":"c"}`, new(badTag), badTag{"a", "b", "c"}, nil}, // syntax errors - {`{"X": "foo", "Y"}`, nil, nil, SyntaxError("invalid character '}' after object key")}, + {`{"X": "foo", "Y"}`, nil, nil, &SyntaxError{"invalid character '}' after object key", 17}}, // composite tests {allValueIndent, new(All), allValue, nil}, @@ -125,12 +125,12 @@ func TestMarshalBadUTF8(t *testing.T) { } func TestUnmarshal(t *testing.T) { - var scan scanner for i, tt := range unmarshalTests { + var scan scanner in := []byte(tt.in) if err := checkValid(in, &scan); err != nil { if !reflect.DeepEqual(err, tt.err) { - t.Errorf("#%d: checkValid: %v", i, err) + t.Errorf("#%d: checkValid: %#v", i, err) continue } } @@ -138,8 +138,7 @@ func TestUnmarshal(t *testing.T) { continue } // v = new(right-type) - v := reflect.NewValue(tt.ptr) - v.Set(reflect.Zero(v.Type().Elem()).Addr()) + v := reflect.New(reflect.TypeOf(tt.ptr).Elem()) if err := Unmarshal([]byte(in), v.Interface()); !reflect.DeepEqual(err, tt.err) { t.Errorf("#%d: %v want %v", i, err, tt.err) continue diff --git a/src/pkg/json/encode.go b/src/pkg/json/encode.go index dfa3c59da..ec0a14a6a 100644 --- a/src/pkg/json/encode.go +++ b/src/pkg/json/encode.go @@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The json package implements encoding and decoding of JSON objects as -// defined in RFC 4627. +// Package json implements encoding and decoding of JSON objects as defined in +// RFC 4627. package json import ( @@ -172,7 +172,7 @@ func (e *encodeState) marshal(v interface{}) (err os.Error) { err = r.(os.Error) } }() - e.reflectValue(reflect.NewValue(v)) + e.reflectValue(reflect.ValueOf(v)) return nil } @@ -180,7 +180,7 @@ func (e *encodeState) error(err os.Error) { panic(err) } -var byteSliceType = reflect.Typeof([]byte(nil)) +var byteSliceType = reflect.TypeOf([]byte(nil)) func (e *encodeState) reflectValue(v reflect.Value) { if !v.IsValid() { diff --git a/src/pkg/json/scanner.go b/src/pkg/json/scanner.go index e98ddef5c..49c2edd54 100644 --- a/src/pkg/json/scanner.go +++ b/src/pkg/json/scanner.go @@ -23,6 +23,7 @@ import ( func checkValid(data []byte, scan *scanner) os.Error { scan.reset() for _, c := range data { + scan.bytes++ if scan.step(scan, int(c)) == scanError { return scan.err } @@ -56,10 +57,12 @@ func nextValue(data []byte, scan *scanner) (value, rest []byte, err os.Error) { } // A SyntaxError is a description of a JSON syntax error. -type SyntaxError string - -func (e SyntaxError) String() string { return string(e) } +type SyntaxError struct { + msg string // description of error + Offset int64 // error occurred after reading Offset bytes +} +func (e *SyntaxError) String() string { return e.msg } // A scanner is a JSON scanning state machine. // Callers call scan.reset() and then pass bytes in one at a time @@ -89,6 +92,9 @@ type scanner struct { // 1-byte redo (see undo method) redoCode int redoState func(*scanner, int) int + + // total bytes consumed, updated by decoder.Decode + bytes int64 } // These values are returned by the state transition functions @@ -148,7 +154,7 @@ func (s *scanner) eof() int { return scanEnd } if s.err == nil { - s.err = SyntaxError("unexpected end of JSON input") + s.err = &SyntaxError{"unexpected end of JSON input", s.bytes} } return scanError } @@ -581,7 +587,7 @@ func stateError(s *scanner, c int) int { // error records an error and switches to the error state. func (s *scanner) error(c int, context string) int { s.step = stateError - s.err = SyntaxError("invalid character " + quoteChar(c) + " " + context) + s.err = &SyntaxError{"invalid character " + quoteChar(c) + " " + context, s.bytes} return scanError } diff --git a/src/pkg/json/stream.go b/src/pkg/json/stream.go index cb9b16559..f143b3f0a 100644 --- a/src/pkg/json/stream.go +++ b/src/pkg/json/stream.go @@ -23,8 +23,8 @@ func NewDecoder(r io.Reader) *Decoder { return &Decoder{r: r} } -// Decode reads the next JSON-encoded value from the -// connection and stores it in the value pointed to by v. +// Decode reads the next JSON-encoded value from its +// input and stores it in the value pointed to by v. // // See the documentation for Unmarshal for details about // the conversion of JSON into a Go value. @@ -62,6 +62,7 @@ Input: for { // Look in the buffer for a new value. for i, c := range dec.buf[scanp:] { + dec.scan.bytes++ v := dec.scan.step(&dec.scan, int(c)) if v == scanEnd { scanp += i diff --git a/src/pkg/log/log.go b/src/pkg/log/log.go index 33140ee08..00bce6a17 100644 --- a/src/pkg/log/log.go +++ b/src/pkg/log/log.go @@ -2,9 +2,9 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Simple logging package. It defines a type, Logger, with methods -// for formatting output. It also has a predefined 'standard' Logger -// accessible through helper functions Print[f|ln], Fatal[f|ln], and +// Package log implements a simple logging package. It defines a type, Logger, +// with methods for formatting output. It also has a predefined 'standard' +// Logger accessible through helper functions Print[f|ln], Fatal[f|ln], and // Panic[f|ln], which are easier to use than creating a Logger manually. // That logger writes to standard error and prints the date and time // of each logged message. diff --git a/src/pkg/math/const.go b/src/pkg/math/const.go index b53527a4f..a108d3e29 100644 --- a/src/pkg/math/const.go +++ b/src/pkg/math/const.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The math package provides basic constants and mathematical functions. +// Package math provides basic constants and mathematical functions. package math // Mathematical constants. diff --git a/src/pkg/mime/mediatype.go b/src/pkg/mime/mediatype.go index eb629aa6f..f28ff3e96 100644 --- a/src/pkg/mime/mediatype.go +++ b/src/pkg/mime/mediatype.go @@ -6,10 +6,30 @@ package mime import ( "bytes" + "fmt" + "os" "strings" "unicode" ) +func validMediaTypeOrDisposition(s string) bool { + typ, rest := consumeToken(s) + if typ == "" { + return false + } + if rest == "" { + return true + } + if !strings.HasPrefix(rest, "/") { + return false + } + subtype, rest := consumeToken(rest[1:]) + if subtype == "" { + return false + } + return rest == "" +} + // ParseMediaType parses a media type value and any optional // parameters, per RFC 1531. Media types are the values in // Content-Type and Content-Disposition headers (RFC 2183). On @@ -22,25 +42,112 @@ func ParseMediaType(v string) (mediatype string, params map[string]string) { i = len(v) } mediatype = strings.TrimSpace(strings.ToLower(v[0:i])) + if !validMediaTypeOrDisposition(mediatype) { + return "", nil + } + params = make(map[string]string) + // Map of base parameter name -> parameter name -> value + // for parameters containing a '*' character. + // Lazily initialized. + var continuation map[string]map[string]string + v = v[i:] for len(v) > 0 { v = strings.TrimLeftFunc(v, unicode.IsSpace) if len(v) == 0 { - return + break } key, value, rest := consumeMediaParam(v) if key == "" { + if strings.TrimSpace(rest) == ";" { + // Ignore trailing semicolons. + // Not an error. + return + } // Parse error. return "", nil } - params[key] = value + + pmap := params + if idx := strings.Index(key, "*"); idx != -1 { + baseName := key[:idx] + if continuation == nil { + continuation = make(map[string]map[string]string) + } + var ok bool + if pmap, ok = continuation[baseName]; !ok { + continuation[baseName] = make(map[string]string) + pmap = continuation[baseName] + } + } + if _, exists := pmap[key]; exists { + // Duplicate parameter name is bogus. + return "", nil + } + pmap[key] = value v = rest } + + // Stitch together any continuations or things with stars + // (i.e. RFC 2231 things with stars: "foo*0" or "foo*") + var buf bytes.Buffer + for key, pieceMap := range continuation { + singlePartKey := key + "*" + if v, ok := pieceMap[singlePartKey]; ok { + decv := decode2231Enc(v) + params[key] = decv + continue + } + + buf.Reset() + valid := false + for n := 0; ; n++ { + simplePart := fmt.Sprintf("%s*%d", key, n) + if v, ok := pieceMap[simplePart]; ok { + valid = true + buf.WriteString(v) + continue + } + encodedPart := simplePart + "*" + if v, ok := pieceMap[encodedPart]; ok { + valid = true + if n == 0 { + buf.WriteString(decode2231Enc(v)) + } else { + decv, _ := percentHexUnescape(v) + buf.WriteString(decv) + } + } else { + break + } + } + if valid { + params[key] = buf.String() + } + } + return } +func decode2231Enc(v string) string { + sv := strings.Split(v, "'", 3) + if len(sv) != 3 { + return "" + } + // TODO: ignoring lang in sv[1] for now. If anybody needs it we'll + // need to decide how to expose it in the API. But I'm not sure + // anybody uses it in practice. + charset := strings.ToLower(sv[0]) + if charset != "us-ascii" && charset != "utf-8" { + // TODO: unsupported encoding + return "" + } + encv, _ := percentHexUnescape(sv[2]) + return encv +} + func isNotTokenChar(rune int) bool { return !IsTokenChar(rune) } @@ -66,10 +173,12 @@ func consumeToken(v string) (token, rest string) { // quoted-string) and the rest of the string. On failure, returns // ("", v). func consumeValue(v string) (value, rest string) { - if !strings.HasPrefix(v, `"`) { + if !strings.HasPrefix(v, `"`) && !strings.HasPrefix(v, `'`) { return consumeToken(v) } + leadQuote := int(v[0]) + // parse a quoted-string rest = v[1:] // consume the leading quote buffer := new(bytes.Buffer) @@ -78,17 +187,14 @@ func consumeValue(v string) (value, rest string) { for idx, rune = range rest { switch { case nextIsLiteral: - if rune >= 0x80 { - return "", v - } buffer.WriteRune(rune) nextIsLiteral = false - case rune == '"': + case rune == leadQuote: return buffer.String(), rest[idx+1:] - case IsQText(rune): - buffer.WriteRune(rune) case rune == '\\': nextIsLiteral = true + case rune != '\r' && rune != '\n': + buffer.WriteRune(rune) default: return "", v } @@ -108,13 +214,79 @@ func consumeMediaParam(v string) (param, value, rest string) { if param == "" { return "", "", v } + + rest = strings.TrimLeftFunc(rest, unicode.IsSpace) if !strings.HasPrefix(rest, "=") { return "", "", v } rest = rest[1:] // consume equals sign + rest = strings.TrimLeftFunc(rest, unicode.IsSpace) value, rest = consumeValue(rest) if value == "" { return "", "", v } return param, value, rest } + +func percentHexUnescape(s string) (string, os.Error) { + // Count %, check that they're well-formed. + percents := 0 + for i := 0; i < len(s); { + if s[i] != '%' { + i++ + continue + } + percents++ + if i+2 >= len(s) || !ishex(s[i+1]) || !ishex(s[i+2]) { + s = s[i:] + if len(s) > 3 { + s = s[0:3] + } + return "", fmt.Errorf("mime: bogus characters after %%: %q", s) + } + i += 3 + } + if percents == 0 { + return s, nil + } + + t := make([]byte, len(s)-2*percents) + j := 0 + for i := 0; i < len(s); { + switch s[i] { + case '%': + t[j] = unhex(s[i+1])<<4 | unhex(s[i+2]) + j++ + i += 3 + default: + t[j] = s[i] + j++ + i++ + } + } + return string(t), nil +} + +func ishex(c byte) bool { + switch { + case '0' <= c && c <= '9': + return true + case 'a' <= c && c <= 'f': + return true + case 'A' <= c && c <= 'F': + return true + } + return false +} + +func unhex(c byte) byte { + switch { + case '0' <= c && c <= '9': + return c - '0' + case 'a' <= c && c <= 'f': + return c - 'a' + 10 + case 'A' <= c && c <= 'F': + return c - 'A' + 10 + } + return 0 +} diff --git a/src/pkg/mime/mediatype_test.go b/src/pkg/mime/mediatype_test.go index 4891e899d..454ddd037 100644 --- a/src/pkg/mime/mediatype_test.go +++ b/src/pkg/mime/mediatype_test.go @@ -5,6 +5,7 @@ package mime import ( + "reflect" "testing" ) @@ -85,23 +86,152 @@ func TestConsumeMediaParam(t *testing.T) { } } +type mediaTypeTest struct { + in string + t string + p map[string]string +} + func TestParseMediaType(t *testing.T) { - tests := [...]string{ - `form-data; name="foo"`, - ` form-data ; name=foo`, - `FORM-DATA;name="foo"`, - ` FORM-DATA ; name="foo"`, - ` FORM-DATA ; name="foo"`, - `form-data; key=value; blah="value";name="foo" `, + // Convenience map initializer + m := func(s ...string) map[string]string { + sm := make(map[string]string) + for i := 0; i < len(s); i += 2 { + sm[s[i]] = s[i+1] + } + return sm + } + + nameFoo := map[string]string{"name": "foo"} + tests := []mediaTypeTest{ + {`form-data; name="foo"`, "form-data", nameFoo}, + {` form-data ; name=foo`, "form-data", nameFoo}, + {`FORM-DATA;name="foo"`, "form-data", nameFoo}, + {` FORM-DATA ; name="foo"`, "form-data", nameFoo}, + {` FORM-DATA ; name="foo"`, "form-data", nameFoo}, + + {`form-data; key=value; blah="value";name="foo" `, + "form-data", + m("key", "value", "blah", "value", "name", "foo")}, + + {`foo; key=val1; key=the-key-appears-again-which-is-bogus`, + "", m()}, + + // From RFC 2231: + {`application/x-stuff; title*=us-ascii'en-us'This%20is%20%2A%2A%2Afun%2A%2A%2A`, + "application/x-stuff", + m("title", "This is ***fun***")}, + + {`message/external-body; access-type=URL; ` + + `URL*0="ftp://";` + + `URL*1="cs.utk.edu/pub/moore/bulk-mailer/bulk-mailer.tar"`, + "message/external-body", + m("access-type", "URL", + "URL", "ftp://cs.utk.edu/pub/moore/bulk-mailer/bulk-mailer.tar")}, + + {`application/x-stuff; ` + + `title*0*=us-ascii'en'This%20is%20even%20more%20; ` + + `title*1*=%2A%2A%2Afun%2A%2A%2A%20; ` + + `title*2="isn't it!"`, + "application/x-stuff", + m("title", "This is even more ***fun*** isn't it!")}, + + // Tests from http://greenbytes.de/tech/tc2231/ + // TODO(bradfitz): add the rest of the tests from that site. + {`attachment; filename="f\oo.html"`, + "attachment", + m("filename", "foo.html")}, + {`attachment; filename="\"quoting\" tested.html"`, + "attachment", + m("filename", `"quoting" tested.html`)}, + {`attachment; filename="Here's a semicolon;.html"`, + "attachment", + m("filename", "Here's a semicolon;.html")}, + {`attachment; foo="\"\\";filename="foo.html"`, + "attachment", + m("foo", "\"\\", "filename", "foo.html")}, + {`attachment; filename=foo.html`, + "attachment", + m("filename", "foo.html")}, + {`attachment; filename=foo.html ;`, + "attachment", + m("filename", "foo.html")}, + {`attachment; filename='foo.html'`, + "attachment", + m("filename", "foo.html")}, + {`attachment; filename="foo-%41.html"`, + "attachment", + m("filename", "foo-%41.html")}, + {`attachment; filename="foo-%\41.html"`, + "attachment", + m("filename", "foo-%41.html")}, + {`filename=foo.html`, + "", m()}, + {`x=y; filename=foo.html`, + "", m()}, + {`"foo; filename=bar;baz"; filename=qux`, + "", m()}, + {`inline; attachment; filename=foo.html`, + "", m()}, + {`attachment; filename="foo.html".txt`, + "", m()}, + {`attachment; filename="bar`, + "", m()}, + {`attachment; creation-date="Wed, 12 Feb 1997 16:29:51 -0500"`, + "attachment", + m("creation-date", "Wed, 12 Feb 1997 16:29:51 -0500")}, + {`foobar`, "foobar", m()}, + {`attachment; filename* =UTF-8''foo-%c3%a4.html`, + "attachment", + m("filename", "foo-ä.html")}, + {`attachment; filename*=UTF-8''A-%2541.html`, + "attachment", + m("filename", "A-%41.html")}, + {`attachment; filename*0="foo."; filename*1="html"`, + "attachment", + m("filename", "foo.html")}, + {`attachment; filename*0*=UTF-8''foo-%c3%a4; filename*1=".html"`, + "attachment", + m("filename", "foo-ä.html")}, + {`attachment; filename*0="foo"; filename*01="bar"`, + "attachment", + m("filename", "foo")}, + {`attachment; filename*0="foo"; filename*2="bar"`, + "attachment", + m("filename", "foo")}, + {`attachment; filename*1="foo"; filename*2="bar"`, + "attachment", m()}, + {`attachment; filename*1="bar"; filename*0="foo"`, + "attachment", + m("filename", "foobar")}, + {`attachment; filename="foo-ae.html"; filename*=UTF-8''foo-%c3%a4.html`, + "attachment", + m("filename", "foo-ä.html")}, + {`attachment; filename*=UTF-8''foo-%c3%a4.html; filename="foo-ae.html"`, + "attachment", + m("filename", "foo-ä.html")}, + + // Browsers also just send UTF-8 directly without RFC 2231, + // at least when the source page is served with UTF-8. + {`form-data; firstname="Брэд"; lastname="Фицпатрик"`, + "form-data", + m("firstname", "Брэд", "lastname", "Фицпатрик")}, } for _, test := range tests { - mt, params := ParseMediaType(test) - if mt != "form-data" { - t.Errorf("expected type form-data for %s, got [%s]", test, mt) + mt, params := ParseMediaType(test.in) + if g, e := mt, test.t; g != e { + t.Errorf("for input %q, expected type %q, got %q", + test.in, e, g) + continue + } + if len(params) == 0 && len(test.p) == 0 { continue } - if params["name"] != "foo" { - t.Errorf("expected name=foo for %s", test) + if !reflect.DeepEqual(params, test.p) { + t.Errorf("for input %q, wrong params.\n"+ + "expected: %#v\n"+ + " got: %#v", + test.in, test.p, params) } } } diff --git a/src/pkg/mime/multipart/Makefile b/src/pkg/mime/multipart/Makefile index 5a7b98d03..5051f0df1 100644 --- a/src/pkg/mime/multipart/Makefile +++ b/src/pkg/mime/multipart/Makefile @@ -6,6 +6,7 @@ include ../../../Make.inc TARG=mime/multipart GOFILES=\ + formdata.go\ multipart.go\ include ../../../Make.pkg diff --git a/src/pkg/mime/multipart/formdata.go b/src/pkg/mime/multipart/formdata.go new file mode 100644 index 000000000..287938557 --- /dev/null +++ b/src/pkg/mime/multipart/formdata.go @@ -0,0 +1,169 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package multipart + +import ( + "bytes" + "io" + "io/ioutil" + "net/textproto" + "os" +) + +// TODO(adg,bradfitz): find a way to unify the DoS-prevention strategy here +// with that of the http package's ParseForm. + +// ReadForm parses an entire multipart message whose parts have +// a Content-Disposition of "form-data". +// It stores up to maxMemory bytes of the file parts in memory +// and the remainder on disk in temporary files. +func (r *multiReader) ReadForm(maxMemory int64) (f *Form, err os.Error) { + form := &Form{make(map[string][]string), make(map[string][]*FileHeader)} + defer func() { + if err != nil { + form.RemoveAll() + } + }() + + maxValueBytes := int64(10 << 20) // 10 MB is a lot of text. + for { + p, err := r.NextPart() + if err != nil { + return nil, err + } + if p == nil { + break + } + + name := p.FormName() + if name == "" { + continue + } + var filename string + if p.dispositionParams != nil { + filename = p.dispositionParams["filename"] + } + + var b bytes.Buffer + + if filename == "" { + // value, store as string in memory + n, err := io.Copyn(&b, p, maxValueBytes) + if err != nil && err != os.EOF { + return nil, err + } + maxValueBytes -= n + if maxValueBytes == 0 { + return nil, os.NewError("multipart: message too large") + } + form.Value[name] = append(form.Value[name], b.String()) + continue + } + + // file, store in memory or on disk + fh := &FileHeader{ + Filename: filename, + Header: p.Header, + } + n, err := io.Copyn(&b, p, maxMemory+1) + if err != nil && err != os.EOF { + return nil, err + } + if n > maxMemory { + // too big, write to disk and flush buffer + file, err := ioutil.TempFile("", "multipart-") + if err != nil { + return nil, err + } + defer file.Close() + _, err = io.Copy(file, io.MultiReader(&b, p)) + if err != nil { + os.Remove(file.Name()) + return nil, err + } + fh.tmpfile = file.Name() + } else { + fh.content = b.Bytes() + maxMemory -= n + } + form.File[name] = append(form.File[name], fh) + } + + return form, nil +} + +// Form is a parsed multipart form. +// Its File parts are stored either in memory or on disk, +// and are accessible via the *FileHeader's Open method. +// Its Value parts are stored as strings. +// Both are keyed by field name. +type Form struct { + Value map[string][]string + File map[string][]*FileHeader +} + +// RemoveAll removes any temporary files associated with a Form. +func (f *Form) RemoveAll() os.Error { + var err os.Error + for _, fhs := range f.File { + for _, fh := range fhs { + if fh.tmpfile != "" { + e := os.Remove(fh.tmpfile) + if e != nil && err == nil { + err = e + } + } + } + } + return err +} + +// A FileHeader describes a file part of a multipart request. +type FileHeader struct { + Filename string + Header textproto.MIMEHeader + + content []byte + tmpfile string +} + +// Open opens and returns the FileHeader's associated File. +func (fh *FileHeader) Open() (File, os.Error) { + if b := fh.content; b != nil { + r := io.NewSectionReader(sliceReaderAt(b), 0, int64(len(b))) + return sectionReadCloser{r}, nil + } + return os.Open(fh.tmpfile) +} + +// File is an interface to access the file part of a multipart message. +// Its contents may be either stored in memory or on disk. +// If stored on disk, the File's underlying concrete type will be an *os.File. +type File interface { + io.Reader + io.ReaderAt + io.Seeker + io.Closer +} + +// helper types to turn a []byte into a File + +type sectionReadCloser struct { + *io.SectionReader +} + +func (rc sectionReadCloser) Close() os.Error { + return nil +} + +type sliceReaderAt []byte + +func (r sliceReaderAt) ReadAt(b []byte, off int64) (int, os.Error) { + if int(off) >= len(r) || off < 0 { + return 0, os.EINVAL + } + n := copy(b, r[int(off):]) + return n, nil +} diff --git a/src/pkg/mime/multipart/formdata_test.go b/src/pkg/mime/multipart/formdata_test.go new file mode 100644 index 000000000..b56e2a430 --- /dev/null +++ b/src/pkg/mime/multipart/formdata_test.go @@ -0,0 +1,87 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package multipart + +import ( + "bytes" + "io" + "os" + "regexp" + "testing" +) + +func TestReadForm(t *testing.T) { + testBody := regexp.MustCompile("\n").ReplaceAllString(message, "\r\n") + b := bytes.NewBufferString(testBody) + r := NewReader(b, boundary) + f, err := r.ReadForm(25) + if err != nil { + t.Fatal("ReadForm:", err) + } + defer f.RemoveAll() + if g, e := f.Value["texta"][0], textaValue; g != e { + t.Errorf("texta value = %q, want %q", g, e) + } + if g, e := f.Value["textb"][0], textbValue; g != e { + t.Errorf("texta value = %q, want %q", g, e) + } + fd := testFile(t, f.File["filea"][0], "filea.txt", fileaContents) + if _, ok := fd.(*os.File); ok { + t.Error("file is *os.File, should not be") + } + fd = testFile(t, f.File["fileb"][0], "fileb.txt", filebContents) + if _, ok := fd.(*os.File); !ok { + t.Error("file has unexpected underlying type %T", fd) + } +} + +func testFile(t *testing.T, fh *FileHeader, efn, econtent string) File { + if fh.Filename != efn { + t.Errorf("filename = %q, want %q", fh.Filename, efn) + } + f, err := fh.Open() + if err != nil { + t.Fatal("opening file:", err) + } + b := new(bytes.Buffer) + _, err = io.Copy(b, f) + if err != nil { + t.Fatal("copying contents:", err) + } + if g := b.String(); g != econtent { + t.Errorf("contents = %q, want %q", g, econtent) + } + return f +} + +const ( + fileaContents = "This is a test file." + filebContents = "Another test file." + textaValue = "foo" + textbValue = "bar" + boundary = `MyBoundary` +) + +const message = ` +--MyBoundary +Content-Disposition: form-data; name="filea"; filename="filea.txt" +Content-Type: text/plain + +` + fileaContents + ` +--MyBoundary +Content-Disposition: form-data; name="fileb"; filename="fileb.txt" +Content-Type: text/plain + +` + filebContents + ` +--MyBoundary +Content-Disposition: form-data; name="texta" + +` + textaValue + ` +--MyBoundary +Content-Disposition: form-data; name="textb" + +` + textbValue + ` +--MyBoundary-- +` diff --git a/src/pkg/mime/multipart/multipart.go b/src/pkg/mime/multipart/multipart.go index 0a65a447d..e0b747c3f 100644 --- a/src/pkg/mime/multipart/multipart.go +++ b/src/pkg/mime/multipart/multipart.go @@ -16,6 +16,7 @@ import ( "bufio" "bytes" "io" + "io/ioutil" "mime" "net/textproto" "os" @@ -34,6 +35,12 @@ type Reader interface { // reports errors, or on truncated or otherwise malformed // input. NextPart() (*Part, os.Error) + + // ReadForm parses an entire multipart message whose parts have + // a Content-Disposition of "form-data". + // It stores up to maxMemory bytes of the file parts in memory + // and the remainder on disk in temporary files. + ReadForm(maxMemory int64) (*Form, os.Error) } // A Part represents a single part in a multipart body. @@ -45,6 +52,8 @@ type Part struct { buffer *bytes.Buffer mr *multiReader + + dispositionParams map[string]string } // FormName returns the name parameter if p has a Content-Disposition @@ -52,15 +61,19 @@ type Part struct { func (p *Part) FormName() string { // See http://tools.ietf.org/html/rfc2183 section 2 for EBNF // of Content-Disposition value format. + if p.dispositionParams != nil { + return p.dispositionParams["name"] + } v := p.Header.Get("Content-Disposition") if v == "" { return "" } - d, params := mime.ParseMediaType(v) - if d != "form-data" { + if d, params := mime.ParseMediaType(v); d != "form-data" { return "" + } else { + p.dispositionParams = params } - return params["name"] + return p.dispositionParams["name"] } // NewReader creates a new multipart Reader reading from r using the @@ -76,14 +89,6 @@ func NewReader(reader io.Reader, boundary string) Reader { // Implementation .... -type devNullWriter bool - -func (*devNullWriter) Write(p []byte) (n int, err os.Error) { - return len(p), nil -} - -var devNull = devNullWriter(false) - func newPart(mr *multiReader) (bp *Part, err os.Error) { bp = new(Part) bp.Header = make(map[string][]string) @@ -97,10 +102,11 @@ func newPart(mr *multiReader) (bp *Part, err os.Error) { func (bp *Part) populateHeaders() os.Error { for { - line, err := bp.mr.bufReader.ReadString('\n') + lineBytes, err := bp.mr.bufReader.ReadSlice('\n') if err != nil { return err } + line := string(lineBytes) if line == "\n" || line == "\r\n" { return nil } @@ -157,7 +163,7 @@ func (bp *Part) Read(p []byte) (n int, err os.Error) { } func (bp *Part) Close() os.Error { - io.Copy(&devNull, bp) + io.Copy(ioutil.Discard, bp) return nil } @@ -179,11 +185,12 @@ func (mr *multiReader) eof() bool { } func (mr *multiReader) readLine() bool { - line, err := mr.bufReader.ReadString('\n') + lineBytes, err := mr.bufReader.ReadSlice('\n') if err != nil { // TODO: care about err being EOF or not? return false } + line := string(lineBytes) mr.bufferedLine = &line return true } diff --git a/src/pkg/mime/multipart/multipart_test.go b/src/pkg/mime/multipart/multipart_test.go index 1f3d32d7e..f8f10f3e1 100644 --- a/src/pkg/mime/multipart/multipart_test.go +++ b/src/pkg/mime/multipart/multipart_test.go @@ -9,6 +9,7 @@ import ( "fmt" "io" "json" + "os" "regexp" "strings" "testing" @@ -205,3 +206,34 @@ func TestVariousTextLineEndings(t *testing.T) { } } + +type maliciousReader struct { + t *testing.T + n int +} + +const maxReadThreshold = 1 << 20 + +func (mr *maliciousReader) Read(b []byte) (n int, err os.Error) { + mr.n += len(b) + if mr.n >= maxReadThreshold { + mr.t.Fatal("too much was read") + return 0, os.EOF + } + return len(b), nil +} + +func TestLineLimit(t *testing.T) { + mr := &maliciousReader{t: t} + r := NewReader(mr, "fooBoundary") + part, err := r.NextPart() + if part != nil { + t.Errorf("unexpected part read") + } + if err == nil { + t.Errorf("expected an error") + } + if mr.n >= maxReadThreshold { + t.Errorf("expected to read < %d bytes; read %d", maxReadThreshold, mr.n) + } +} diff --git a/src/pkg/mime/type.go b/src/pkg/mime/type.go index 6fe0ed5fd..8c43b81b0 100644 --- a/src/pkg/mime/type.go +++ b/src/pkg/mime/type.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The mime package implements parts of the MIME spec. +// Package mime implements parts of the MIME spec. package mime import ( diff --git a/src/pkg/net/Makefile b/src/pkg/net/Makefile index 7ce650279..221871cb1 100644 --- a/src/pkg/net/Makefile +++ b/src/pkg/net/Makefile @@ -6,7 +6,6 @@ include ../../Make.inc TARG=net GOFILES=\ - cgo_stub.go\ dial.go\ dnsmsg.go\ fd_$(GOOS).go\ @@ -31,6 +30,10 @@ GOFILES_freebsd=\ dnsclient.go\ port.go\ +CGOFILES_freebsd=\ + cgo_bsd.go\ + cgo_unix.go\ + GOFILES_darwin=\ newpollserver.go\ fd.go\ @@ -38,6 +41,10 @@ GOFILES_darwin=\ dnsconfig.go\ dnsclient.go\ port.go\ + +CGOFILES_darwin=\ + cgo_bsd.go\ + cgo_unix.go\ GOFILES_linux=\ newpollserver.go\ @@ -47,10 +54,23 @@ GOFILES_linux=\ dnsclient.go\ port.go\ +ifeq ($(GOARCH),arm) +# ARM has no cgo, so use the stubs. +GOFILES_linux+=cgo_stub.go +else +CGOFILES_linux=\ + cgo_linux.go\ + cgo_unix.go +endif + GOFILES_windows=\ + cgo_stub.go\ resolv_windows.go\ file_windows.go\ GOFILES+=$(GOFILES_$(GOOS)) +ifneq ($(CGOFILES_$(GOOS)),) +CGOFILES+=$(CGOFILES_$(GOOS)) +endif include ../../Make.pkg diff --git a/src/pkg/net/cgo_bsd.go b/src/pkg/net/cgo_bsd.go new file mode 100644 index 000000000..4984df4a2 --- /dev/null +++ b/src/pkg/net/cgo_bsd.go @@ -0,0 +1,14 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +/* +#include <netdb.h> +*/ +import "C" + +func cgoAddrInfoMask() C.int { + return C.AI_MASK +} diff --git a/src/pkg/net/cgo_linux.go b/src/pkg/net/cgo_linux.go new file mode 100644 index 000000000..8d4413d2d --- /dev/null +++ b/src/pkg/net/cgo_linux.go @@ -0,0 +1,14 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +/* +#include <netdb.h> +*/ +import "C" + +func cgoAddrInfoMask() C.int { + return C.AI_CANONNAME | C.AI_V4MAPPED | C.AI_ALL +} diff --git a/src/pkg/net/cgo_stub.go b/src/pkg/net/cgo_stub.go index e28f6622e..c6277cb65 100644 --- a/src/pkg/net/cgo_stub.go +++ b/src/pkg/net/cgo_stub.go @@ -19,3 +19,7 @@ func cgoLookupPort(network, service string) (port int, err os.Error, completed b func cgoLookupIP(name string) (addrs []IP, err os.Error, completed bool) { return nil, nil, false } + +func cgoLookupCNAME(name string) (cname string, err os.Error, completed bool) { + return "", nil, false +} diff --git a/src/pkg/net/cgo_unix.go b/src/pkg/net/cgo_unix.go new file mode 100644 index 000000000..a3711d601 --- /dev/null +++ b/src/pkg/net/cgo_unix.go @@ -0,0 +1,148 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +/* +#include <sys/types.h> +#include <sys/socket.h> +#include <netinet/in.h> +#include <netdb.h> +#include <stdlib.h> +#include <unistd.h> +#include <string.h> +*/ +import "C" + +import ( + "os" + "syscall" + "unsafe" +) + +func cgoLookupHost(name string) (addrs []string, err os.Error, completed bool) { + ip, err, completed := cgoLookupIP(name) + for _, p := range ip { + addrs = append(addrs, p.String()) + } + return +} + +func cgoLookupPort(net, service string) (port int, err os.Error, completed bool) { + var res *C.struct_addrinfo + var hints C.struct_addrinfo + + switch net { + case "": + // no hints + case "tcp", "tcp4", "tcp6": + hints.ai_socktype = C.SOCK_STREAM + hints.ai_protocol = C.IPPROTO_TCP + case "udp", "udp4", "udp6": + hints.ai_socktype = C.SOCK_DGRAM + hints.ai_protocol = C.IPPROTO_UDP + default: + return 0, UnknownNetworkError(net), true + } + if len(net) >= 4 { + switch net[3] { + case '4': + hints.ai_family = C.AF_INET + case '6': + hints.ai_family = C.AF_INET6 + } + } + + s := C.CString(service) + defer C.free(unsafe.Pointer(s)) + if C.getaddrinfo(nil, s, &hints, &res) == 0 { + defer C.freeaddrinfo(res) + for r := res; r != nil; r = r.ai_next { + switch r.ai_family { + default: + continue + case C.AF_INET: + sa := (*syscall.RawSockaddrInet4)(unsafe.Pointer(r.ai_addr)) + p := (*[2]byte)(unsafe.Pointer(&sa.Port)) + return int(p[0])<<8 | int(p[1]), nil, true + case C.AF_INET6: + sa := (*syscall.RawSockaddrInet6)(unsafe.Pointer(r.ai_addr)) + p := (*[2]byte)(unsafe.Pointer(&sa.Port)) + return int(p[0])<<8 | int(p[1]), nil, true + } + } + } + return 0, &AddrError{"unknown port", net + "/" + service}, true +} + +func cgoLookupIPCNAME(name string) (addrs []IP, cname string, err os.Error, completed bool) { + var res *C.struct_addrinfo + var hints C.struct_addrinfo + + // NOTE(rsc): In theory there are approximately balanced + // arguments for and against including AI_ADDRCONFIG + // in the flags (it includes IPv4 results only on IPv4 systems, + // and similarly for IPv6), but in practice setting it causes + // getaddrinfo to return the wrong canonical name on Linux. + // So definitely leave it out. + hints.ai_flags = (C.AI_ALL | C.AI_V4MAPPED | C.AI_CANONNAME) & cgoAddrInfoMask() + + h := C.CString(name) + defer C.free(unsafe.Pointer(h)) + gerrno, err := C.getaddrinfo(h, nil, &hints, &res) + if gerrno != 0 { + var str string + if gerrno == C.EAI_NONAME { + str = noSuchHost + } else if gerrno == C.EAI_SYSTEM { + str = err.String() + } else { + str = C.GoString(C.gai_strerror(gerrno)) + } + return nil, "", &DNSError{Error: str, Name: name}, true + } + defer C.freeaddrinfo(res) + if res != nil { + cname = C.GoString(res.ai_canonname) + if cname == "" { + cname = name + } + if len(cname) > 0 && cname[len(cname)-1] != '.' { + cname += "." + } + } + for r := res; r != nil; r = r.ai_next { + // Everything comes back twice, once for UDP and once for TCP. + if r.ai_socktype != C.SOCK_STREAM { + continue + } + switch r.ai_family { + default: + continue + case C.AF_INET: + sa := (*syscall.RawSockaddrInet4)(unsafe.Pointer(r.ai_addr)) + addrs = append(addrs, copyIP(sa.Addr[:])) + case C.AF_INET6: + sa := (*syscall.RawSockaddrInet6)(unsafe.Pointer(r.ai_addr)) + addrs = append(addrs, copyIP(sa.Addr[:])) + } + } + return addrs, cname, nil, true +} + +func cgoLookupIP(name string) (addrs []IP, err os.Error, completed bool) { + addrs, _, err, completed = cgoLookupIPCNAME(name) + return +} + +func cgoLookupCNAME(name string) (cname string, err os.Error, completed bool) { + _, cname, err, completed = cgoLookupIPCNAME(name) + return +} + +func copyIP(x IP) IP { + y := make(IP, len(x)) + copy(y, x) + return y +} diff --git a/src/pkg/net/dial.go b/src/pkg/net/dial.go index 66cb09b19..16896b426 100644 --- a/src/pkg/net/dial.go +++ b/src/pkg/net/dial.go @@ -30,7 +30,7 @@ func Dial(net, addr string) (c Conn, err os.Error) { switch net { case "tcp", "tcp4", "tcp6": var ra *TCPAddr - if ra, err = ResolveTCPAddr(raddr); err != nil { + if ra, err = ResolveTCPAddr(net, raddr); err != nil { goto Error } c, err := DialTCP(net, nil, ra) @@ -40,7 +40,7 @@ func Dial(net, addr string) (c Conn, err os.Error) { return c, nil case "udp", "udp4", "udp6": var ra *UDPAddr - if ra, err = ResolveUDPAddr(raddr); err != nil { + if ra, err = ResolveUDPAddr(net, raddr); err != nil { goto Error } c, err := DialUDP(net, nil, ra) @@ -83,7 +83,7 @@ func Listen(net, laddr string) (l Listener, err os.Error) { case "tcp", "tcp4", "tcp6": var la *TCPAddr if laddr != "" { - if la, err = ResolveTCPAddr(laddr); err != nil { + if la, err = ResolveTCPAddr(net, laddr); err != nil { return nil, err } } @@ -116,7 +116,7 @@ func ListenPacket(net, laddr string) (c PacketConn, err os.Error) { case "udp", "udp4", "udp6": var la *UDPAddr if laddr != "" { - if la, err = ResolveUDPAddr(laddr); err != nil { + if la, err = ResolveUDPAddr(net, laddr); err != nil { return nil, err } } diff --git a/src/pkg/net/dialgoogle_test.go b/src/pkg/net/dialgoogle_test.go index 9a9c02ebd..c25089ba4 100644 --- a/src/pkg/net/dialgoogle_test.go +++ b/src/pkg/net/dialgoogle_test.go @@ -56,29 +56,44 @@ var googleaddrs = []string{ } func TestLookupCNAME(t *testing.T) { + if testing.Short() { + // Don't use external network. + t.Logf("skipping external network test during -short") + return + } cname, err := LookupCNAME("www.google.com") - if cname != "www.l.google.com." || err != nil { - t.Errorf(`LookupCNAME("www.google.com.") = %q, %v, want "www.l.google.com.", nil`, cname, err) + if !strings.HasSuffix(cname, ".l.google.com.") || err != nil { + t.Errorf(`LookupCNAME("www.google.com.") = %q, %v, want "*.l.google.com.", nil`, cname, err) } } func TestDialGoogle(t *testing.T) { + if testing.Short() { + // Don't use external network. + t.Logf("skipping external network test during -short") + return + } // If no ipv6 tunnel, don't try the last address. if !*ipv6 { googleaddrs[len(googleaddrs)-1] = "" } - // Insert an actual IP address for google.com + // Insert an actual IPv4 address for google.com // into the table. - addrs, err := LookupIP("www.google.com") if err != nil { t.Fatalf("lookup www.google.com: %v", err) } - if len(addrs) == 0 { - t.Fatalf("no addresses for www.google.com") + var ip IP + for _, addr := range addrs { + if x := addr.To4(); x != nil { + ip = x + break + } + } + if ip == nil { + t.Fatalf("no IPv4 addresses for www.google.com") } - ip := addrs[0].To4() for i, s := range googleaddrs { if strings.Contains(s, "%") { diff --git a/src/pkg/net/dnsclient.go b/src/pkg/net/dnsclient.go index c3e727bce..89f2409bf 100644 --- a/src/pkg/net/dnsclient.go +++ b/src/pkg/net/dnsclient.go @@ -307,17 +307,22 @@ func lookup(name string, qtype uint16) (cname string, addrs []dnsRR, err os.Erro } // goLookupHost is the native Go implementation of LookupHost. +// Used only if cgoLookupHost refuses to handle the request +// (that is, only if cgoLookupHost is the stub in cgo_stub.go). +// Normally we let cgo use the C library resolver instead of +// depending on our lookup code, so that Go and C get the same +// answers. func goLookupHost(name string) (addrs []string, err os.Error) { - onceLoadConfig.Do(loadConfig) - if dnserr != nil || cfg == nil { - err = dnserr - return - } // Use entries from /etc/hosts if they match. addrs = lookupStaticHost(name) if len(addrs) > 0 { return } + onceLoadConfig.Do(loadConfig) + if dnserr != nil || cfg == nil { + err = dnserr + return + } ips, err := goLookupIP(name) if err != nil { return @@ -330,6 +335,11 @@ func goLookupHost(name string) (addrs []string, err os.Error) { } // goLookupIP is the native Go implementation of LookupIP. +// Used only if cgoLookupIP refuses to handle the request +// (that is, only if cgoLookupIP is the stub in cgo_stub.go). +// Normally we let cgo use the C library resolver instead of +// depending on our lookup code, so that Go and C get the same +// answers. func goLookupIP(name string) (addrs []IP, err os.Error) { onceLoadConfig.Do(loadConfig) if dnserr != nil || cfg == nil { @@ -358,11 +368,13 @@ func goLookupIP(name string) (addrs []IP, err os.Error) { return } -// LookupCNAME returns the canonical DNS host for the given name. -// Callers that do not care about the canonical name can call -// LookupHost or LookupIP directly; both take care of resolving -// the canonical name as part of the lookup. -func LookupCNAME(name string) (cname string, err os.Error) { +// goLookupCNAME is the native Go implementation of LookupCNAME. +// Used only if cgoLookupCNAME refuses to handle the request +// (that is, only if cgoLookupCNAME is the stub in cgo_stub.go). +// Normally we let cgo use the C library resolver instead of +// depending on our lookup code, so that Go and C get the same +// answers. +func goLookupCNAME(name string) (cname string, err os.Error) { onceLoadConfig.Do(loadConfig) if dnserr != nil || cfg == nil { err = dnserr diff --git a/src/pkg/net/dnsmsg.go b/src/pkg/net/dnsmsg.go index e8eb8d958..7b8e5c6d3 100644 --- a/src/pkg/net/dnsmsg.go +++ b/src/pkg/net/dnsmsg.go @@ -426,7 +426,7 @@ func packStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok bool) if off+n > len(msg) { return len(msg), false } - reflect.Copy(reflect.NewValue(msg[off:off+n]), fv) + reflect.Copy(reflect.ValueOf(msg[off:off+n]), fv) off += n case reflect.String: // There are multiple string encodings. @@ -456,7 +456,7 @@ func packStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok bool) } func structValue(any interface{}) reflect.Value { - return reflect.NewValue(any).Elem() + return reflect.ValueOf(any).Elem() } func packStruct(any interface{}, msg []byte, off int) (off1 int, ok bool) { @@ -499,7 +499,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo if off+n > len(msg) { return len(msg), false } - reflect.Copy(fv, reflect.NewValue(msg[off:off+n])) + reflect.Copy(fv, reflect.ValueOf(msg[off:off+n])) off += n case reflect.String: var s string diff --git a/src/pkg/net/hosts_test.go b/src/pkg/net/hosts_test.go index 470e35f78..e5793eef2 100644 --- a/src/pkg/net/hosts_test.go +++ b/src/pkg/net/hosts_test.go @@ -5,6 +5,7 @@ package net import ( + "sort" "testing" ) @@ -51,3 +52,17 @@ func TestLookupStaticHost(t *testing.T) { } hostsPath = p } + +func TestLookupHost(t *testing.T) { + // Can't depend on this to return anything in particular, + // but if it does return something, make sure it doesn't + // duplicate addresses (a common bug due to the way + // getaddrinfo works). + addrs, _ := LookupHost("localhost") + sort.SortStrings(addrs) + for i := 0; i+1 < len(addrs); i++ { + if addrs[i] == addrs[i+1] { + t.Fatalf("LookupHost(\"localhost\") = %v, has duplicate addresses", addrs) + } + } +} diff --git a/src/pkg/net/ip.go b/src/pkg/net/ip.go index 12bb6f351..61b2c687e 100644 --- a/src/pkg/net/ip.go +++ b/src/pkg/net/ip.go @@ -75,7 +75,8 @@ var ( // Well-known IPv6 addresses var ( - IPzero = make(IP, IPv6len) // all zeros + IPzero = make(IP, IPv6len) // all zeros + IPv6loopback = IP([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}) ) // Is p all zeros? @@ -436,7 +437,7 @@ func parseIPv6(s string) IP { } // Otherwise must be followed by colon and more. - if s[i] != ':' && i+1 == len(s) { + if s[i] != ':' || i+1 == len(s) { return nil } i++ diff --git a/src/pkg/net/ip_test.go b/src/pkg/net/ip_test.go index f1a4716d2..2008953ef 100644 --- a/src/pkg/net/ip_test.go +++ b/src/pkg/net/ip_test.go @@ -29,6 +29,7 @@ var parseiptests = []struct { {"127.0.0.1", IPv4(127, 0, 0, 1)}, {"127.0.0.256", nil}, {"abc", nil}, + {"123:", nil}, {"::ffff:127.0.0.1", IPv4(127, 0, 0, 1)}, {"2001:4860:0:2001::68", IP{0x20, 0x01, 0x48, 0x60, 0, 0, 0x20, 0x01, diff --git a/src/pkg/net/iprawsock.go b/src/pkg/net/iprawsock.go index 60433303a..5be6fe4e0 100644 --- a/src/pkg/net/iprawsock.go +++ b/src/pkg/net/iprawsock.go @@ -245,7 +245,7 @@ func hostToIP(host string) (ip IP, err os.Error) { err = err1 goto Error } - addr = firstSupportedAddr(addrs) + addr = firstSupportedAddr(anyaddr, addrs) if addr == nil { // should not happen err = &AddrError{"LookupHost returned invalid address", addrs[0]} diff --git a/src/pkg/net/ipsock.go b/src/pkg/net/ipsock.go index 80bc3eea5..e8bcac646 100644 --- a/src/pkg/net/ipsock.go +++ b/src/pkg/net/ipsock.go @@ -35,15 +35,28 @@ func kernelSupportsIPv6() bool { var preferIPv4 = !kernelSupportsIPv6() -func firstSupportedAddr(addrs []string) (addr IP) { +func firstSupportedAddr(filter func(IP) IP, addrs []string) IP { for _, s := range addrs { - addr = ParseIP(s) - if !preferIPv4 || addr.To4() != nil { - break + if addr := filter(ParseIP(s)); addr != nil { + return addr } - addr = nil } - return addr + return nil +} + +func anyaddr(x IP) IP { return x } +func ipv4only(x IP) IP { return x.To4() } + +func ipv6only(x IP) IP { + // Only return addresses that we can use + // with the kernel's IPv6 addressing modes. + // If preferIPv4 is set, it means the IPv6 stack + // cannot take IPv4 addresses directly (we prefer + // to use the IPv4 stack) so reject IPv4 addresses. + if x.To4() != nil && preferIPv4 { + return nil + } + return x } // TODO(rsc): if syscall.OS == "linux", we're supposd to read @@ -131,7 +144,6 @@ func (e InvalidAddrError) String() string { return string(e) } func (e InvalidAddrError) Timeout() bool { return false } func (e InvalidAddrError) Temporary() bool { return false } - func ipToSockaddr(family int, ip IP, port int) (syscall.Sockaddr, os.Error) { switch family { case syscall.AF_INET: @@ -218,13 +230,31 @@ func hostPortToIP(net, hostport string) (ip IP, iport int, err os.Error) { // Try as an IP address. addr = ParseIP(host) if addr == nil { + filter := anyaddr + if len(net) >= 4 && net[3] == '4' { + filter = ipv4only + } else if len(net) >= 4 && net[3] == '6' { + filter = ipv6only + } // Not an IP address. Try as a DNS name. addrs, err1 := LookupHost(host) if err1 != nil { err = err1 goto Error } - addr = firstSupportedAddr(addrs) + if filter == anyaddr { + // We'll take any IP address, but since the dialing code + // does not yet try multiple addresses, prefer to use + // an IPv4 address if possible. This is especially relevant + // if localhost resolves to [ipv6-localhost, ipv4-localhost]. + // Too much code assumes localhost == ipv4-localhost. + addr = firstSupportedAddr(ipv4only, addrs) + if addr == nil { + addr = firstSupportedAddr(anyaddr, addrs) + } + } else { + addr = firstSupportedAddr(filter, addrs) + } if addr == nil { // should not happen err = &AddrError{"LookupHost returned invalid address", addrs[0]} diff --git a/src/pkg/net/lookup.go b/src/pkg/net/lookup.go index 7b2185ed4..eeb22a8ae 100644 --- a/src/pkg/net/lookup.go +++ b/src/pkg/net/lookup.go @@ -36,3 +36,15 @@ func LookupPort(network, service string) (port int, err os.Error) { } return } + +// LookupCNAME returns the canonical DNS host for the given name. +// Callers that do not care about the canonical name can call +// LookupHost or LookupIP directly; both take care of resolving +// the canonical name as part of the lookup. +func LookupCNAME(name string) (cname string, err os.Error) { + cname, err, ok := cgoLookupCNAME(name) + if !ok { + cname, err = goLookupCNAME(name) + } + return +} diff --git a/src/pkg/net/net.go b/src/pkg/net/net.go index 04a898a9a..51db10739 100644 --- a/src/pkg/net/net.go +++ b/src/pkg/net/net.go @@ -2,9 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The net package provides a portable interface to Unix -// networks sockets, including TCP/IP, UDP, domain name -// resolution, and Unix domain sockets. +// Package net provides a portable interface to Unix networks sockets, +// including TCP/IP, UDP, domain name resolution, and Unix domain sockets. package net // TODO(rsc): diff --git a/src/pkg/net/resolv_windows.go b/src/pkg/net/resolv_windows.go index 000c30659..3506ea177 100644 --- a/src/pkg/net/resolv_windows.go +++ b/src/pkg/net/resolv_windows.go @@ -47,7 +47,7 @@ func goLookupIP(name string) (addrs []IP, err os.Error) { return addrs, nil } -func LookupCNAME(name string) (cname string, err os.Error) { +func goLookupCNAME(name string) (cname string, err os.Error) { var r *syscall.DNSRecord e := syscall.DnsQuery(name, syscall.DNS_TYPE_CNAME, 0, nil, &r, nil) if int(e) != 0 { diff --git a/src/pkg/net/server_test.go b/src/pkg/net/server_test.go index 37695a068..075748b83 100644 --- a/src/pkg/net/server_test.go +++ b/src/pkg/net/server_test.go @@ -108,12 +108,10 @@ func doTest(t *testing.T, network, listenaddr, dialaddr string) { } func TestTCPServer(t *testing.T) { - doTest(t, "tcp", "0.0.0.0", "127.0.0.1") - doTest(t, "tcp", "", "127.0.0.1") + doTest(t, "tcp", "127.0.0.1", "127.0.0.1") if kernelSupportsIPv6() { - doTest(t, "tcp", "[::]", "[::ffff:127.0.0.1]") - doTest(t, "tcp", "[::]", "127.0.0.1") - doTest(t, "tcp", "0.0.0.0", "[::ffff:127.0.0.1]") + doTest(t, "tcp", "[::1]", "[::1]") + doTest(t, "tcp", "127.0.0.1", "[::ffff:127.0.0.1]") } } diff --git a/src/pkg/net/sock.go b/src/pkg/net/sock.go index 933700af1..bd88f7ece 100644 --- a/src/pkg/net/sock.go +++ b/src/pkg/net/sock.go @@ -161,7 +161,7 @@ type UnknownSocketError struct { } func (e *UnknownSocketError) String() string { - return "unknown socket address type " + reflect.Typeof(e.sa).String() + return "unknown socket address type " + reflect.TypeOf(e.sa).String() } func sockaddrToString(sa syscall.Sockaddr) (name string, err os.Error) { diff --git a/src/pkg/net/srv_test.go b/src/pkg/net/srv_test.go index 4dd6089cd..f1c7a0ab4 100644 --- a/src/pkg/net/srv_test.go +++ b/src/pkg/net/srv_test.go @@ -8,10 +8,17 @@ package net import ( + "runtime" "testing" ) +var avoidMacFirewall = runtime.GOOS == "darwin" + func TestGoogleSRV(t *testing.T) { + if testing.Short() || avoidMacFirewall { + t.Logf("skipping test to avoid external network") + return + } _, addrs, err := LookupSRV("xmpp-server", "tcp", "google.com") if err != nil { t.Errorf("failed: %s", err) diff --git a/src/pkg/net/tcpsock.go b/src/pkg/net/tcpsock.go index b484be20b..d9aa7cf19 100644 --- a/src/pkg/net/tcpsock.go +++ b/src/pkg/net/tcpsock.go @@ -62,8 +62,8 @@ func (a *TCPAddr) toAddr() sockaddr { // host:port and resolves domain names or port names to // numeric addresses. A literal IPv6 host address must be // enclosed in square brackets, as in "[::]:80". -func ResolveTCPAddr(addr string) (*TCPAddr, os.Error) { - ip, port, err := hostPortToIP("tcp", addr) +func ResolveTCPAddr(network, addr string) (*TCPAddr, os.Error) { + ip, port, err := hostPortToIP(network, addr) if err != nil { return nil, err } diff --git a/src/pkg/net/textproto/textproto.go b/src/pkg/net/textproto/textproto.go index fbfad9d61..9f19b5495 100644 --- a/src/pkg/net/textproto/textproto.go +++ b/src/pkg/net/textproto/textproto.go @@ -2,9 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The textproto package implements generic support for -// text-based request/response protocols in the style of -// HTTP, NNTP, and SMTP. +// Package textproto implements generic support for text-based request/response +// protocols in the style of HTTP, NNTP, and SMTP. // // The package provides: // diff --git a/src/pkg/net/udpsock.go b/src/pkg/net/udpsock.go index 44d618dab..67684471b 100644 --- a/src/pkg/net/udpsock.go +++ b/src/pkg/net/udpsock.go @@ -62,8 +62,8 @@ func (a *UDPAddr) toAddr() sockaddr { // host:port and resolves domain names or port names to // numeric addresses. A literal IPv6 host address must be // enclosed in square brackets, as in "[::]:80". -func ResolveUDPAddr(addr string) (*UDPAddr, os.Error) { - ip, port, err := hostPortToIP("udp", addr) +func ResolveUDPAddr(network, addr string) (*UDPAddr, os.Error) { + ip, port, err := hostPortToIP(network, addr) if err != nil { return nil, err } diff --git a/src/pkg/netchan/export.go b/src/pkg/netchan/export.go index 2209f04e8..1e5ccdb5c 100644 --- a/src/pkg/netchan/export.go +++ b/src/pkg/netchan/export.go @@ -3,7 +3,7 @@ // license that can be found in the LICENSE file. /* - The netchan package implements type-safe networked channels: + Package netchan implements type-safe networked channels: it allows the two ends of a channel to appear on different computers connected by a network. It does this by transporting data sent to a channel on one machine so it can be recovered @@ -111,9 +111,9 @@ func (client *expClient) getChan(hdr *header, dir Dir) *netChan { // data arrives from the client. func (client *expClient) run() { hdr := new(header) - hdrValue := reflect.NewValue(hdr) + hdrValue := reflect.ValueOf(hdr) req := new(request) - reqValue := reflect.NewValue(req) + reqValue := reflect.ValueOf(req) error := new(error) for { *hdr = header{} @@ -221,7 +221,7 @@ func (client *expClient) serveSend(hdr header) { return } // Create a new value for each received item. - val := reflect.Zero(nch.ch.Type().Elem()) + val := reflect.New(nch.ch.Type().Elem()).Elem() if err := client.decode(val); err != nil { expLog("value decode:", err, "; type ", nch.ch.Type()) return @@ -341,7 +341,7 @@ func (exp *Exporter) Sync(timeout int64) os.Error { } func checkChan(chT interface{}, dir Dir) (reflect.Value, os.Error) { - chanType := reflect.Typeof(chT) + chanType := reflect.TypeOf(chT) if chanType.Kind() != reflect.Chan { return reflect.Value{}, os.ErrorString("not a channel") } @@ -359,7 +359,7 @@ func checkChan(chT interface{}, dir Dir) (reflect.Value, os.Error) { return reflect.Value{}, os.ErrorString("to import/export with Recv, must provide chan<-") } } - return reflect.NewValue(chT), nil + return reflect.ValueOf(chT), nil } // Export exports a channel of a given type and specified direction. The diff --git a/src/pkg/netchan/import.go b/src/pkg/netchan/import.go index 9921486bd..0a700ca2b 100644 --- a/src/pkg/netchan/import.go +++ b/src/pkg/netchan/import.go @@ -73,10 +73,10 @@ func (imp *Importer) shutdown() { func (imp *Importer) run() { // Loop on responses; requests are sent by ImportNValues() hdr := new(header) - hdrValue := reflect.NewValue(hdr) + hdrValue := reflect.ValueOf(hdr) ackHdr := new(header) err := new(error) - errValue := reflect.NewValue(err) + errValue := reflect.ValueOf(err) for { *hdr = header{} if e := imp.decode(hdrValue); e != nil { @@ -133,7 +133,7 @@ func (imp *Importer) run() { ackHdr.SeqNum = hdr.SeqNum imp.encode(ackHdr, payAck, nil) // Create a new value for each received item. - value := reflect.Zero(nch.ch.Type().Elem()) + value := reflect.New(nch.ch.Type().Elem()).Elem() if e := imp.decode(value); e != nil { impLog("importer value decode:", e) return diff --git a/src/pkg/os/file.go b/src/pkg/os/file.go index 3aad80234..dff8fa862 100644 --- a/src/pkg/os/file.go +++ b/src/pkg/os/file.go @@ -2,12 +2,13 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The os package provides a platform-independent interface to operating -// system functionality. The design is Unix-like. +// Package os provides a platform-independent interface to operating system +// functionality. The design is Unix-like. package os import ( "runtime" + "sync" "syscall" ) @@ -15,8 +16,9 @@ import ( type File struct { fd int name string - dirinfo *dirInfo // nil unless directory being read - nepipe int // number of consecutive EPIPE in Write + dirinfo *dirInfo // nil unless directory being read + nepipe int // number of consecutive EPIPE in Write + l sync.Mutex // used to implement windows pread/pwrite } // Fd returns the integer Unix file descriptor referencing the open file. @@ -30,7 +32,7 @@ func NewFile(fd int, name string) *File { if fd < 0 { return nil } - f := &File{fd, name, nil, 0} + f := &File{fd: fd, name: name} runtime.SetFinalizer(f, (*File).Close) return f } @@ -85,7 +87,7 @@ func (file *File) Read(b []byte) (n int, err Error) { if file == nil { return 0, EINVAL } - n, e := syscall.Read(file.fd, b) + n, e := file.read(b) if n < 0 { n = 0 } @@ -107,7 +109,7 @@ func (file *File) ReadAt(b []byte, off int64) (n int, err Error) { return 0, EINVAL } for len(b) > 0 { - m, e := syscall.Pread(file.fd, b, off) + m, e := file.pread(b, off) if m == 0 && !iserror(e) { return n, EOF } @@ -129,7 +131,7 @@ func (file *File) Write(b []byte) (n int, err Error) { if file == nil { return 0, EINVAL } - n, e := syscall.Write(file.fd, b) + n, e := file.write(b) if n < 0 { n = 0 } @@ -150,7 +152,7 @@ func (file *File) WriteAt(b []byte, off int64) (n int, err Error) { return 0, EINVAL } for len(b) > 0 { - m, e := syscall.Pwrite(file.fd, b, off) + m, e := file.pwrite(b, off) if iserror(e) { err = &PathError{"write", file.name, Errno(e)} break @@ -167,7 +169,7 @@ func (file *File) WriteAt(b []byte, off int64) (n int, err Error) { // relative to the current offset, and 2 means relative to the end. // It returns the new offset and an Error, if any. func (file *File) Seek(offset int64, whence int) (ret int64, err Error) { - r, e := syscall.Seek(file.fd, offset, whence) + r, e := file.seek(offset, whence) if !iserror(e) && file.dirinfo != nil && r != 0 { e = syscall.EISDIR } diff --git a/src/pkg/os/file_plan9.go b/src/pkg/os/file_plan9.go index c8d0efba4..7b473f802 100644 --- a/src/pkg/os/file_plan9.go +++ b/src/pkg/os/file_plan9.go @@ -117,6 +117,39 @@ func (f *File) Sync() (err Error) { return nil } +// read reads up to len(b) bytes from the File. +// It returns the number of bytes read and an error, if any. +func (f *File) read(b []byte) (n int, err syscall.Error) { + return syscall.Read(f.fd, b) +} + +// pread reads len(b) bytes from the File starting at byte offset off. +// It returns the number of bytes read and the error, if any. +// EOF is signaled by a zero count with err set to nil. +func (f *File) pread(b []byte, off int64) (n int, err syscall.Error) { + return syscall.Pread(f.fd, b, off) +} + +// write writes len(b) bytes to the File. +// It returns the number of bytes written and an error, if any. +func (f *File) write(b []byte) (n int, err syscall.Error) { + return syscall.Write(f.fd, b) +} + +// pwrite writes len(b) bytes to the File starting at byte offset off. +// It returns the number of bytes written and an error, if any. +func (f *File) pwrite(b []byte, off int64) (n int, err syscall.Error) { + return syscall.Pwrite(f.fd, b, off) +} + +// seek sets the offset for the next Read or Write on file to offset, interpreted +// according to whence: 0 means relative to the origin of the file, 1 means +// relative to the current offset, and 2 means relative to the end. +// It returns the new offset and an error, if any. +func (f *File) seek(offset int64, whence int) (ret int64, err syscall.Error) { + return syscall.Seek(f.fd, offset, whence) +} + // Truncate changes the size of the named file. // If the file is a symbolic link, it changes the size of the link's target. func Truncate(name string, size int64) Error { diff --git a/src/pkg/os/file_posix.go b/src/pkg/os/file_posix.go index 5151df498..f1191d61f 100644 --- a/src/pkg/os/file_posix.go +++ b/src/pkg/os/file_posix.go @@ -10,11 +10,13 @@ import ( "syscall" ) +func sigpipe() // implemented in package runtime + func epipecheck(file *File, e int) { if e == syscall.EPIPE { file.nepipe++ if file.nepipe >= 10 { - Exit(syscall.EPIPE) + sigpipe() } } else { file.nepipe = 0 diff --git a/src/pkg/os/file_unix.go b/src/pkg/os/file_unix.go index f2b94f4c2..2fb28df65 100644 --- a/src/pkg/os/file_unix.go +++ b/src/pkg/os/file_unix.go @@ -96,6 +96,39 @@ func (file *File) Readdir(count int) (fi []FileInfo, err Error) { return } +// read reads up to len(b) bytes from the File. +// It returns the number of bytes read and an error, if any. +func (f *File) read(b []byte) (n int, err int) { + return syscall.Read(f.fd, b) +} + +// pread reads len(b) bytes from the File starting at byte offset off. +// It returns the number of bytes read and the error, if any. +// EOF is signaled by a zero count with err set to 0. +func (f *File) pread(b []byte, off int64) (n int, err int) { + return syscall.Pread(f.fd, b, off) +} + +// write writes len(b) bytes to the File. +// It returns the number of bytes written and an error, if any. +func (f *File) write(b []byte) (n int, err int) { + return syscall.Write(f.fd, b) +} + +// pwrite writes len(b) bytes to the File starting at byte offset off. +// It returns the number of bytes written and an error, if any. +func (f *File) pwrite(b []byte, off int64) (n int, err int) { + return syscall.Pwrite(f.fd, b, off) +} + +// seek sets the offset for the next Read or Write on file to offset, interpreted +// according to whence: 0 means relative to the origin of the file, 1 means +// relative to the current offset, and 2 means relative to the end. +// It returns the new offset and an error, if any. +func (f *File) seek(offset int64, whence int) (ret int64, err int) { + return syscall.Seek(f.fd, offset, whence) +} + // Truncate changes the size of the named file. // If the file is a symbolic link, it changes the size of the link's target. func Truncate(name string, size int64) Error { diff --git a/src/pkg/os/file_windows.go b/src/pkg/os/file_windows.go index 862baf6b9..95f60b735 100644 --- a/src/pkg/os/file_windows.go +++ b/src/pkg/os/file_windows.go @@ -165,6 +165,77 @@ func (file *File) Readdir(count int) (fi []FileInfo, err Error) { return fi, nil } +// read reads up to len(b) bytes from the File. +// It returns the number of bytes read and an error, if any. +func (f *File) read(b []byte) (n int, err int) { + f.l.Lock() + defer f.l.Unlock() + return syscall.Read(f.fd, b) +} + +// pread reads len(b) bytes from the File starting at byte offset off. +// It returns the number of bytes read and the error, if any. +// EOF is signaled by a zero count with err set to 0. +func (f *File) pread(b []byte, off int64) (n int, err int) { + f.l.Lock() + defer f.l.Unlock() + curoffset, e := syscall.Seek(f.fd, 0, 1) + if e != 0 { + return 0, e + } + defer syscall.Seek(f.fd, curoffset, 0) + o := syscall.Overlapped{ + OffsetHigh: uint32(off >> 32), + Offset: uint32(off), + } + var done uint32 + e = syscall.ReadFile(int32(f.fd), b, &done, &o) + if e != 0 { + return 0, e + } + return int(done), 0 +} + +// write writes len(b) bytes to the File. +// It returns the number of bytes written and an error, if any. +func (f *File) write(b []byte) (n int, err int) { + f.l.Lock() + defer f.l.Unlock() + return syscall.Write(f.fd, b) +} + +// pwrite writes len(b) bytes to the File starting at byte offset off. +// It returns the number of bytes written and an error, if any. +func (f *File) pwrite(b []byte, off int64) (n int, err int) { + f.l.Lock() + defer f.l.Unlock() + curoffset, e := syscall.Seek(f.fd, 0, 1) + if e != 0 { + return 0, e + } + defer syscall.Seek(f.fd, curoffset, 0) + o := syscall.Overlapped{ + OffsetHigh: uint32(off >> 32), + Offset: uint32(off), + } + var done uint32 + e = syscall.WriteFile(int32(f.fd), b, &done, &o) + if e != 0 { + return 0, e + } + return int(done), 0 +} + +// seek sets the offset for the next Read or Write on file to offset, interpreted +// according to whence: 0 means relative to the origin of the file, 1 means +// relative to the current offset, and 2 means relative to the end. +// It returns the new offset and an error, if any. +func (f *File) seek(offset int64, whence int) (ret int64, err int) { + f.l.Lock() + defer f.l.Unlock() + return syscall.Seek(f.fd, offset, whence) +} + // Truncate changes the size of the named file. // If the file is a symbolic link, it changes the size of the link's target. func Truncate(name string, size int64) Error { diff --git a/src/pkg/os/inotify/inotify_linux.go b/src/pkg/os/inotify/inotify_linux.go index 8b5c30e0d..7c7b7698f 100644 --- a/src/pkg/os/inotify/inotify_linux.go +++ b/src/pkg/os/inotify/inotify_linux.go @@ -3,7 +3,7 @@ // license that can be found in the LICENSE file. /* -This package implements a wrapper for the Linux inotify system. +Package inotify implements a wrapper for the Linux inotify system. Example: watcher, err := inotify.NewWatcher() diff --git a/src/pkg/os/os_test.go b/src/pkg/os/os_test.go index 551b86508..65475c118 100644 --- a/src/pkg/os/os_test.go +++ b/src/pkg/os/os_test.go @@ -567,8 +567,8 @@ func checkSize(t *testing.T, f *File, size int64) { } } -func TestTruncate(t *testing.T) { - f := newFile("TestTruncate", t) +func TestFTruncate(t *testing.T) { + f := newFile("TestFTruncate", t) defer Remove(f.Name()) defer f.Close() @@ -585,6 +585,24 @@ func TestTruncate(t *testing.T) { checkSize(t, f, 13+9) // wrote at offset past where hello, world was. } +func TestTruncate(t *testing.T) { + f := newFile("TestTruncate", t) + defer Remove(f.Name()) + defer f.Close() + + checkSize(t, f, 0) + f.Write([]byte("hello, world\n")) + checkSize(t, f, 13) + Truncate(f.Name(), 10) + checkSize(t, f, 10) + Truncate(f.Name(), 1024) + checkSize(t, f, 1024) + Truncate(f.Name(), 0) + checkSize(t, f, 0) + f.Write([]byte("surprise!")) + checkSize(t, f, 13+9) // wrote at offset past where hello, world was. +} + // Use TempDir() to make sure we're on a local file system, // so that timings are not distorted by latency and caching. // On NFS, timings can be off due to caching of meta-data on @@ -886,6 +904,18 @@ func TestAppend(t *testing.T) { if s != "new|append" { t.Fatalf("writeFile: have %q want %q", s, "new|append") } + s = writeFile(t, f, O_CREATE|O_APPEND|O_RDWR, "|append") + if s != "new|append|append" { + t.Fatalf("writeFile: have %q want %q", s, "new|append|append") + } + err := Remove(f) + if err != nil { + t.Fatalf("Remove: %v", err) + } + s = writeFile(t, f, O_CREATE|O_APPEND|O_RDWR, "new&append") + if s != "new&append" { + t.Fatalf("writeFile: have %q want %q", s, "new&append") + } } func TestStatDirWithTrailingSlash(t *testing.T) { diff --git a/src/pkg/os/user/Makefile b/src/pkg/os/user/Makefile new file mode 100644 index 000000000..731f7999a --- /dev/null +++ b/src/pkg/os/user/Makefile @@ -0,0 +1,26 @@ +# Copyright 2011 The Go Authors. All rights reserved. +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +include ../../../Make.inc + +TARG=os/user +GOFILES=\ + user.go\ + +ifneq ($(GOARCH),arm) +CGOFILES_linux=\ + lookup_unix.go +CGOFILES_freebsd=\ + lookup_unix.go +CGOFILES_darwin=\ + lookup_unix.go +endif + +ifneq ($(CGOFILES_$(GOOS)),) +CGOFILES+=$(CGOFILES_$(GOOS)) +else +GOFILES+=lookup_stubs.go +endif + +include ../../../Make.pkg diff --git a/src/pkg/os/user/lookup_stubs.go b/src/pkg/os/user/lookup_stubs.go new file mode 100644 index 000000000..2f08f70fd --- /dev/null +++ b/src/pkg/os/user/lookup_stubs.go @@ -0,0 +1,19 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package user + +import ( + "fmt" + "os" + "runtime" +) + +func Lookup(username string) (*User, os.Error) { + return nil, fmt.Errorf("user: Lookup not implemented on %s/%s", runtime.GOOS, runtime.GOARCH) +} + +func LookupId(int) (*User, os.Error) { + return nil, fmt.Errorf("user: LookupId not implemented on %s/%s", runtime.GOOS, runtime.GOARCH) +} diff --git a/src/pkg/os/user/lookup_unix.go b/src/pkg/os/user/lookup_unix.go new file mode 100644 index 000000000..678de802b --- /dev/null +++ b/src/pkg/os/user/lookup_unix.go @@ -0,0 +1,104 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package user + +import ( + "fmt" + "os" + "runtime" + "strings" + "unsafe" +) + +/* +#include <unistd.h> +#include <sys/types.h> +#include <pwd.h> +#include <stdlib.h> + +static int mygetpwuid_r(int uid, struct passwd *pwd, + char *buf, size_t buflen, struct passwd **result) { + return getpwuid_r(uid, pwd, buf, buflen, result); +} +*/ +import "C" + +// Lookup looks up a user by username. If the user cannot be found, +// the returned error is of type UnknownUserError. +func Lookup(username string) (*User, os.Error) { + return lookup(-1, username, true) +} + +// LookupId looks up a user by userid. If the user cannot be found, +// the returned error is of type UnknownUserIdError. +func LookupId(uid int) (*User, os.Error) { + return lookup(uid, "", false) +} + +func lookup(uid int, username string, lookupByName bool) (*User, os.Error) { + var pwd C.struct_passwd + var result *C.struct_passwd + + var bufSize C.long + if runtime.GOOS == "freebsd" { + // FreeBSD doesn't have _SC_GETPW_R_SIZE_MAX + // and just returns -1. So just use the same + // size that Linux returns + bufSize = 1024 + } else { + bufSize = C.sysconf(C._SC_GETPW_R_SIZE_MAX) + if bufSize <= 0 || bufSize > 1<<20 { + return nil, fmt.Errorf("user: unreasonable _SC_GETPW_R_SIZE_MAX of %d", bufSize) + } + } + buf := C.malloc(C.size_t(bufSize)) + defer C.free(buf) + var rv C.int + if lookupByName { + nameC := C.CString(username) + defer C.free(unsafe.Pointer(nameC)) + rv = C.getpwnam_r(nameC, + &pwd, + (*C.char)(buf), + C.size_t(bufSize), + &result) + if rv != 0 { + return nil, fmt.Errorf("user: lookup username %s: %s", username, os.Errno(rv)) + } + if result == nil { + return nil, UnknownUserError(username) + } + } else { + // mygetpwuid_r is a wrapper around getpwuid_r to + // to avoid using uid_t because C.uid_t(uid) for + // unknown reasons doesn't work on linux. + rv = C.mygetpwuid_r(C.int(uid), + &pwd, + (*C.char)(buf), + C.size_t(bufSize), + &result) + if rv != 0 { + return nil, fmt.Errorf("user: lookup userid %d: %s", uid, os.Errno(rv)) + } + if result == nil { + return nil, UnknownUserIdError(uid) + } + } + u := &User{ + Uid: int(pwd.pw_uid), + Gid: int(pwd.pw_gid), + Username: C.GoString(pwd.pw_name), + Name: C.GoString(pwd.pw_gecos), + HomeDir: C.GoString(pwd.pw_dir), + } + // The pw_gecos field isn't quite standardized. Some docs + // say: "It is expected to be a comma separated list of + // personal data where the first item is the full name of the + // user." + if i := strings.Index(u.Name, ","); i >= 0 { + u.Name = u.Name[:i] + } + return u, nil +} diff --git a/src/pkg/os/user/user.go b/src/pkg/os/user/user.go new file mode 100644 index 000000000..dd009211d --- /dev/null +++ b/src/pkg/os/user/user.go @@ -0,0 +1,35 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package user allows user account lookups by name or id. +package user + +import ( + "strconv" +) + +// User represents a user account. +type User struct { + Uid int // user id + Gid int // primary group id + Username string + Name string + HomeDir string +} + +// UnknownUserIdError is returned by LookupId when +// a user cannot be found. +type UnknownUserIdError int + +func (e UnknownUserIdError) String() string { + return "user: unknown userid " + strconv.Itoa(int(e)) +} + +// UnknownUserError is returned by Lookup when +// a user cannot be found. +type UnknownUserError string + +func (e UnknownUserError) String() string { + return "user: unknown user " + string(e) +} diff --git a/src/pkg/os/user/user_test.go b/src/pkg/os/user/user_test.go new file mode 100644 index 000000000..2c142bf18 --- /dev/null +++ b/src/pkg/os/user/user_test.go @@ -0,0 +1,61 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package user + +import ( + "os" + "reflect" + "runtime" + "syscall" + "testing" +) + +func skip(t *testing.T) bool { + if runtime.GOARCH == "arm" { + t.Logf("user: cgo not implemented on arm; skipping tests") + return true + } + + if runtime.GOOS == "linux" || runtime.GOOS == "freebsd" || runtime.GOOS == "darwin" { + return false + } + + t.Logf("user: Lookup not implemented on %s; skipping test", runtime.GOOS) + return true +} + +func TestLookup(t *testing.T) { + if skip(t) { + return + } + + // Test LookupId on the current user + uid := syscall.Getuid() + u, err := LookupId(uid) + if err != nil { + t.Fatalf("LookupId: %v", err) + } + if e, g := uid, u.Uid; e != g { + t.Errorf("expected Uid of %d; got %d", e, g) + } + fi, err := os.Stat(u.HomeDir) + if err != nil || !fi.IsDirectory() { + t.Errorf("expected a valid HomeDir; stat(%q): err=%v, IsDirectory=%v", err, fi.IsDirectory()) + } + if u.Username == "" { + t.Fatalf("didn't get a username") + } + + // Test Lookup by username, using the username from LookupId + un, err := Lookup(u.Username) + if err != nil { + t.Fatalf("Lookup: %v", err) + } + if !reflect.DeepEqual(u, un) { + t.Errorf("Lookup by userid vs. name didn't match\n"+ + "LookupId(%d): %#v\n"+ + "Lookup(%q): %#v\n",uid, u, u.Username, un) + } +} diff --git a/src/pkg/path/filepath/path.go b/src/pkg/path/filepath/path.go index de673a725..541a23306 100644 --- a/src/pkg/path/filepath/path.go +++ b/src/pkg/path/filepath/path.go @@ -2,9 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The filepath package implements utility routines for manipulating -// filename paths in a way compatible with the target operating -// system-defined file paths. +// Package filepath implements utility routines for manipulating filename paths +// in a way compatible with the target operating system-defined file paths. package filepath import ( diff --git a/src/pkg/path/path.go b/src/pkg/path/path.go index 658eec093..235384667 100644 --- a/src/pkg/path/path.go +++ b/src/pkg/path/path.go @@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The path package implements utility routines for manipulating -// slash-separated filename paths. +// Package path implements utility routines for manipulating slash-separated +// filename paths. package path import ( diff --git a/src/pkg/reflect/all_test.go b/src/pkg/reflect/all_test.go index bc9157672..5bf65333c 100644 --- a/src/pkg/reflect/all_test.go +++ b/src/pkg/reflect/all_test.go @@ -5,11 +5,13 @@ package reflect_test import ( + "bytes" "container/vector" "fmt" "io" "os" . "reflect" + "runtime" "testing" "unsafe" ) @@ -35,7 +37,7 @@ func assert(t *testing.T, s, want string) { } } -func typestring(i interface{}) string { return Typeof(i).String() } +func typestring(i interface{}) string { return TypeOf(i).String() } var typeTests = []pair{ {struct{ x int }{}, "int"}, @@ -150,50 +152,50 @@ var typeTests = []pair{ b() }) }{}, - "interface { a(func(func(int) int) func(func(int)) int); b() }", + "interface { reflect_test.a(func(func(int) int) func(func(int)) int); reflect_test.b() }", }, } var valueTests = []pair{ - {(int8)(0), "8"}, - {(int16)(0), "16"}, - {(int32)(0), "32"}, - {(int64)(0), "64"}, - {(uint8)(0), "8"}, - {(uint16)(0), "16"}, - {(uint32)(0), "32"}, - {(uint64)(0), "64"}, - {(float32)(0), "256.25"}, - {(float64)(0), "512.125"}, - {(string)(""), "stringy cheese"}, - {(bool)(false), "true"}, - {(*int8)(nil), "*int8(0)"}, - {(**int8)(nil), "**int8(0)"}, - {[5]int32{}, "[5]int32{0, 0, 0, 0, 0}"}, - {(**integer)(nil), "**reflect_test.integer(0)"}, - {(map[string]int32)(nil), "map[string] int32{<can't iterate on maps>}"}, - {(chan<- string)(nil), "chan<- string"}, - {struct { + {new(int8), "8"}, + {new(int16), "16"}, + {new(int32), "32"}, + {new(int64), "64"}, + {new(uint8), "8"}, + {new(uint16), "16"}, + {new(uint32), "32"}, + {new(uint64), "64"}, + {new(float32), "256.25"}, + {new(float64), "512.125"}, + {new(string), "stringy cheese"}, + {new(bool), "true"}, + {new(*int8), "*int8(0)"}, + {new(**int8), "**int8(0)"}, + {new([5]int32), "[5]int32{0, 0, 0, 0, 0}"}, + {new(**integer), "**reflect_test.integer(0)"}, + {new(map[string]int32), "map[string] int32{<can't iterate on maps>}"}, + {new(chan<- string), "chan<- string"}, + {new(func(a int8, b int32)), "func(int8, int32)(0)"}, + {new(struct { c chan *int32 d float32 - }{}, + }), "struct { c chan *int32; d float32 }{chan *int32, 0}", }, - {(func(a int8, b int32))(nil), "func(int8, int32)(0)"}, - {struct{ c func(chan *integer, *int8) }{}, + {new(struct{ c func(chan *integer, *int8) }), "struct { c func(chan *reflect_test.integer, *int8) }{func(chan *reflect_test.integer, *int8)(0)}", }, - {struct { + {new(struct { a int8 b int32 - }{}, + }), "struct { a int8; b int32 }{0, 0}", }, - {struct { + {new(struct { a int8 b int8 c int32 - }{}, + }), "struct { a int8; b int8; c int32 }{0, 0, 0}", }, } @@ -207,13 +209,13 @@ func testType(t *testing.T, i int, typ Type, want string) { func TestTypes(t *testing.T) { for i, tt := range typeTests { - testType(t, i, NewValue(tt.i).Field(0).Type(), tt.s) + testType(t, i, ValueOf(tt.i).Field(0).Type(), tt.s) } } func TestSet(t *testing.T) { for i, tt := range valueTests { - v := NewValue(tt.i) + v := ValueOf(tt.i).Elem() switch v.Kind() { case Int: v.SetInt(132) @@ -257,40 +259,40 @@ func TestSet(t *testing.T) { func TestSetValue(t *testing.T) { for i, tt := range valueTests { - v := NewValue(tt.i) + v := ValueOf(tt.i).Elem() switch v.Kind() { case Int: - v.Set(NewValue(int(132))) + v.Set(ValueOf(int(132))) case Int8: - v.Set(NewValue(int8(8))) + v.Set(ValueOf(int8(8))) case Int16: - v.Set(NewValue(int16(16))) + v.Set(ValueOf(int16(16))) case Int32: - v.Set(NewValue(int32(32))) + v.Set(ValueOf(int32(32))) case Int64: - v.Set(NewValue(int64(64))) + v.Set(ValueOf(int64(64))) case Uint: - v.Set(NewValue(uint(132))) + v.Set(ValueOf(uint(132))) case Uint8: - v.Set(NewValue(uint8(8))) + v.Set(ValueOf(uint8(8))) case Uint16: - v.Set(NewValue(uint16(16))) + v.Set(ValueOf(uint16(16))) case Uint32: - v.Set(NewValue(uint32(32))) + v.Set(ValueOf(uint32(32))) case Uint64: - v.Set(NewValue(uint64(64))) + v.Set(ValueOf(uint64(64))) case Float32: - v.Set(NewValue(float32(256.25))) + v.Set(ValueOf(float32(256.25))) case Float64: - v.Set(NewValue(512.125)) + v.Set(ValueOf(512.125)) case Complex64: - v.Set(NewValue(complex64(532.125 + 10i))) + v.Set(ValueOf(complex64(532.125 + 10i))) case Complex128: - v.Set(NewValue(complex128(564.25 + 1i))) + v.Set(ValueOf(complex128(564.25 + 1i))) case String: - v.Set(NewValue("stringy cheese")) + v.Set(ValueOf("stringy cheese")) case Bool: - v.Set(NewValue(true)) + v.Set(ValueOf(true)) } s := valueToString(v) if s != tt.s { @@ -316,7 +318,7 @@ var valueToStringTests = []pair{ func TestValueToString(t *testing.T) { for i, test := range valueToStringTests { - s := valueToString(NewValue(test.i)) + s := valueToString(ValueOf(test.i)) if s != test.s { t.Errorf("#%d: have %#q, want %#q", i, s, test.s) } @@ -324,7 +326,7 @@ func TestValueToString(t *testing.T) { } func TestArrayElemSet(t *testing.T) { - v := NewValue([10]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) + v := ValueOf(&[10]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}).Elem() v.Index(4).SetInt(123) s := valueToString(v) const want = "[10]int{1, 2, 3, 4, 123, 6, 7, 8, 9, 10}" @@ -332,7 +334,7 @@ func TestArrayElemSet(t *testing.T) { t.Errorf("[10]int: have %#q want %#q", s, want) } - v = NewValue([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) + v = ValueOf([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) v.Index(4).SetInt(123) s = valueToString(v) const want1 = "[]int{1, 2, 3, 4, 123, 6, 7, 8, 9, 10}" @@ -344,15 +346,15 @@ func TestArrayElemSet(t *testing.T) { func TestPtrPointTo(t *testing.T) { var ip *int32 var i int32 = 1234 - vip := NewValue(&ip) - vi := NewValue(i) + vip := ValueOf(&ip) + vi := ValueOf(&i).Elem() vip.Elem().Set(vi.Addr()) if *ip != 1234 { t.Errorf("got %d, want 1234", *ip) } ip = nil - vp := NewValue(ip) + vp := ValueOf(&ip).Elem() vp.Set(Zero(vp.Type())) if ip != nil { t.Errorf("got non-nil (%p), want nil", ip) @@ -362,7 +364,7 @@ func TestPtrPointTo(t *testing.T) { func TestPtrSetNil(t *testing.T) { var i int32 = 1234 ip := &i - vip := NewValue(&ip) + vip := ValueOf(&ip) vip.Elem().Set(Zero(vip.Elem().Type())) if ip != nil { t.Errorf("got non-nil (%d), want nil", *ip) @@ -371,7 +373,7 @@ func TestPtrSetNil(t *testing.T) { func TestMapSetNil(t *testing.T) { m := make(map[string]int) - vm := NewValue(&m) + vm := ValueOf(&m) vm.Elem().Set(Zero(vm.Elem().Type())) if m != nil { t.Errorf("got non-nil (%p), want nil", m) @@ -380,10 +382,10 @@ func TestMapSetNil(t *testing.T) { func TestAll(t *testing.T) { - testType(t, 1, Typeof((int8)(0)), "int8") - testType(t, 2, Typeof((*int8)(nil)).Elem(), "int8") + testType(t, 1, TypeOf((int8)(0)), "int8") + testType(t, 2, TypeOf((*int8)(nil)).Elem(), "int8") - typ := Typeof((*struct { + typ := TypeOf((*struct { c chan *int32 d float32 })(nil)) @@ -405,22 +407,22 @@ func TestAll(t *testing.T) { t.Errorf("FieldByName says absent field is present") } - typ = Typeof([32]int32{}) + typ = TypeOf([32]int32{}) testType(t, 7, typ, "[32]int32") testType(t, 8, typ.Elem(), "int32") - typ = Typeof((map[string]*int32)(nil)) + typ = TypeOf((map[string]*int32)(nil)) testType(t, 9, typ, "map[string] *int32") mtyp := typ testType(t, 10, mtyp.Key(), "string") testType(t, 11, mtyp.Elem(), "*int32") - typ = Typeof((chan<- string)(nil)) + typ = TypeOf((chan<- string)(nil)) testType(t, 12, typ, "chan<- string") testType(t, 13, typ.Elem(), "string") // make sure tag strings are not part of element type - typ = Typeof(struct { + typ = TypeOf(struct { d []uint32 "TAG" }{}).Field(0).Type testType(t, 14, typ, "[]uint32") @@ -428,23 +430,23 @@ func TestAll(t *testing.T) { func TestInterfaceGet(t *testing.T) { var inter struct { - e interface{} + E interface{} } - inter.e = 123.456 - v1 := NewValue(&inter) + inter.E = 123.456 + v1 := ValueOf(&inter) v2 := v1.Elem().Field(0) assert(t, v2.Type().String(), "interface { }") i2 := v2.Interface() - v3 := NewValue(i2) + v3 := ValueOf(i2) assert(t, v3.Type().String(), "float64") } func TestInterfaceValue(t *testing.T) { var inter struct { - e interface{} + E interface{} } - inter.e = 123.456 - v1 := NewValue(&inter) + inter.E = 123.456 + v1 := ValueOf(&inter) v2 := v1.Elem().Field(0) assert(t, v2.Type().String(), "interface { }") v3 := v2.Elem() @@ -452,13 +454,14 @@ func TestInterfaceValue(t *testing.T) { i3 := v2.Interface() if _, ok := i3.(float64); !ok { - t.Error("v2.Interface() did not return float64, got ", Typeof(i3)) + t.Error("v2.Interface() did not return float64, got ", TypeOf(i3)) } } func TestFunctionValue(t *testing.T) { - v := NewValue(func() {}) - if v.Interface() != v.Interface() { + var x interface{} = func() {} + v := ValueOf(x) + if v.Interface() != v.Interface() || v.Interface() != x { t.Fatalf("TestFunction != itself") } assert(t, v.Type().String(), "func()") @@ -471,6 +474,18 @@ var appendTests = []struct { {make([]int, 2, 4), []int{22, 33, 44}}, } +func sameInts(x, y []int) bool { + if len(x) != len(y) { + return false + } + for i, xx := range x { + if xx != y[i] { + return false + } + } + return true +} + func TestAppend(t *testing.T) { for i, test := range appendTests { origLen, extraLen := len(test.orig), len(test.extra) @@ -478,15 +493,15 @@ func TestAppend(t *testing.T) { // Convert extra from []int to []Value. e0 := make([]Value, len(test.extra)) for j, e := range test.extra { - e0[j] = NewValue(e) + e0[j] = ValueOf(e) } // Convert extra from []int to *SliceValue. - e1 := NewValue(test.extra) + e1 := ValueOf(test.extra) // Test Append. - a0 := NewValue(test.orig) + a0 := ValueOf(test.orig) have0 := Append(a0, e0...).Interface().([]int) - if !DeepEqual(have0, want) { - t.Errorf("Append #%d: have %v, want %v", i, have0, want) + if !sameInts(have0, want) { + t.Errorf("Append #%d: have %v, want %v (%p %p)", i, have0, want, test.orig, have0) } // Check that the orig and extra slices were not modified. if len(test.orig) != origLen { @@ -496,9 +511,9 @@ func TestAppend(t *testing.T) { t.Errorf("Append #%d extraLen: have %v, want %v", i, len(test.extra), extraLen) } // Test AppendSlice. - a1 := NewValue(test.orig) + a1 := ValueOf(test.orig) have1 := AppendSlice(a1, e1).Interface().([]int) - if !DeepEqual(have1, want) { + if !sameInts(have1, want) { t.Errorf("AppendSlice #%d: have %v, want %v", i, have1, want) } // Check that the orig and extra slices were not modified. @@ -520,8 +535,10 @@ func TestCopy(t *testing.T) { t.Fatalf("b != c before test") } } - aa := NewValue(a) - ab := NewValue(b) + a1 := a + b1 := b + aa := ValueOf(&a1).Elem() + ab := ValueOf(&b1).Elem() for tocopy := 1; tocopy <= 7; tocopy++ { aa.SetLen(tocopy) Copy(ab, aa) @@ -548,14 +565,41 @@ func TestCopy(t *testing.T) { } } +func TestCopyArray(t *testing.T) { + a := [8]int{1, 2, 3, 4, 10, 9, 8, 7} + b := [11]int{11, 22, 33, 44, 1010, 99, 88, 77, 66, 55, 44} + c := b + aa := ValueOf(&a).Elem() + ab := ValueOf(&b).Elem() + Copy(ab, aa) + for i := 0; i < len(a); i++ { + if a[i] != b[i] { + t.Errorf("(i) a[%d]=%d, b[%d]=%d", i, a[i], i, b[i]) + } + } + for i := len(a); i < len(b); i++ { + if b[i] != c[i] { + if i < len(a) { + t.Errorf("(ii) a[%d]=%d, b[%d]=%d, c[%d]=%d", + i, a[i], i, b[i], i, c[i]) + } else { + t.Errorf("(iii) b[%d]=%d, c[%d]=%d", + i, b[i], i, c[i]) + } + } else { + t.Logf("elem %d is okay\n", i) + } + } +} + func TestBigUnnamedStruct(t *testing.T) { b := struct{ a, b, c, d int64 }{1, 2, 3, 4} - v := NewValue(b) + v := ValueOf(b) b1 := v.Interface().(struct { a, b, c, d int64 }) if b1.a != b.a || b1.b != b.b || b1.c != b.c || b1.d != b.d { - t.Errorf("NewValue(%v).Interface().(*Big) = %v", b, b1) + t.Errorf("ValueOf(%v).Interface().(*Big) = %v", b, b1) } } @@ -565,10 +609,10 @@ type big struct { func TestBigStruct(t *testing.T) { b := big{1, 2, 3, 4, 5} - v := NewValue(b) + v := ValueOf(b) b1 := v.Interface().(big) if b1.a != b.a || b1.b != b.b || b1.c != b.c || b1.d != b.d || b1.e != b.e { - t.Errorf("NewValue(%v).Interface().(big) = %v", b, b1) + t.Errorf("ValueOf(%v).Interface().(big) = %v", b, b1) } } @@ -632,15 +676,15 @@ func TestDeepEqual(t *testing.T) { } } -func TestTypeof(t *testing.T) { +func TestTypeOf(t *testing.T) { for _, test := range deepEqualTests { - v := NewValue(test.a) + v := ValueOf(test.a) if !v.IsValid() { continue } - typ := Typeof(test.a) + typ := TypeOf(test.a) if typ != v.Type() { - t.Errorf("Typeof(%v) = %v, but NewValue(%v).Type() = %v", test.a, typ, test.a, v.Type()) + t.Errorf("TypeOf(%v) = %v, but ValueOf(%v).Type() = %v", test.a, typ, test.a, v.Type()) } } } @@ -690,7 +734,7 @@ func TestDeepEqualComplexStructInequality(t *testing.T) { func check2ndField(x interface{}, offs uintptr, t *testing.T) { - s := NewValue(x) + s := ValueOf(x) f := s.Type().Field(1) if f.Offset != offs { t.Error("mismatched offsets in structure alignment:", f.Offset, offs) @@ -723,16 +767,16 @@ func TestAlignment(t *testing.T) { } func Nil(a interface{}, t *testing.T) { - n := NewValue(a).Field(0) + n := ValueOf(a).Field(0) if !n.IsNil() { t.Errorf("%v should be nil", a) } } func NotNil(a interface{}, t *testing.T) { - n := NewValue(a).Field(0) + n := ValueOf(a).Field(0) if n.IsNil() { - t.Errorf("value of type %v should not be nil", NewValue(a).Type().String()) + t.Errorf("value of type %v should not be nil", ValueOf(a).Type().String()) } } @@ -748,7 +792,7 @@ func TestIsNil(t *testing.T) { struct{ x []string }{}, } for _, ts := range doNil { - ty := Typeof(ts).Field(0).Type + ty := TypeOf(ts).Field(0).Type v := Zero(ty) v.IsNil() // panics if not okay to call } @@ -803,50 +847,22 @@ func TestInterfaceExtraction(t *testing.T) { } s.w = os.Stdout - v := Indirect(NewValue(&s)).Field(0).Interface() + v := Indirect(ValueOf(&s)).Field(0).Interface() if v != s.w.(interface{}) { t.Error("Interface() on interface: ", v, s.w) } } -func TestInterfaceEditing(t *testing.T) { - // strings are bigger than one word, - // so the interface conversion allocates - // memory to hold a string and puts that - // pointer in the interface. - var i interface{} = "hello" - - // if i pass the interface value by value - // to NewValue, i should get a fresh copy - // of the value. - v := NewValue(i) - - // and setting that copy to "bye" should - // not change the value stored in i. - v.SetString("bye") - if i.(string) != "hello" { - t.Errorf(`Set("bye") changed i to %s`, i.(string)) - } - - // the same should be true of smaller items. - i = 123 - v = NewValue(i) - v.SetInt(234) - if i.(int) != 123 { - t.Errorf("Set(234) changed i to %d", i.(int)) - } -} - func TestNilPtrValueSub(t *testing.T) { var pi *int - if pv := NewValue(pi); pv.Elem().IsValid() { - t.Error("NewValue((*int)(nil)).Elem().IsValid()") + if pv := ValueOf(pi); pv.Elem().IsValid() { + t.Error("ValueOf((*int)(nil)).Elem().IsValid()") } } func TestMap(t *testing.T) { m := map[string]int{"a": 1, "b": 2} - mv := NewValue(m) + mv := ValueOf(m) if n := mv.Len(); n != len(m) { t.Errorf("Len = %d, want %d", n, len(m)) } @@ -866,15 +882,15 @@ func TestMap(t *testing.T) { i++ // Check that value lookup is correct. - vv := mv.MapIndex(NewValue(k)) + vv := mv.MapIndex(ValueOf(k)) if vi := vv.Int(); vi != int64(v) { t.Errorf("Key %q: have value %d, want %d", k, vi, v) } // Copy into new map. - newmap.SetMapIndex(NewValue(k), NewValue(v)) + newmap.SetMapIndex(ValueOf(k), ValueOf(v)) } - vv := mv.MapIndex(NewValue("not-present")) + vv := mv.MapIndex(ValueOf("not-present")) if vv.IsValid() { t.Errorf("Invalid key: got non-nil value %s", valueToString(vv)) } @@ -891,13 +907,13 @@ func TestMap(t *testing.T) { } } - newmap.SetMapIndex(NewValue("a"), Value{}) + newmap.SetMapIndex(ValueOf("a"), Value{}) v, ok := newm["a"] if ok { t.Errorf("newm[\"a\"] = %d after delete", v) } - mv = NewValue(&m).Elem() + mv = ValueOf(&m).Elem() mv.Set(Zero(mv.Type())) if m != nil { t.Errorf("mv.Set(nil) failed") @@ -913,14 +929,14 @@ func TestChan(t *testing.T) { switch loop { case 1: c = make(chan int, 1) - cv = NewValue(c) + cv = ValueOf(c) case 0: - cv = MakeChan(Typeof(c), 1) + cv = MakeChan(TypeOf(c), 1) c = cv.Interface().(chan int) } // Send - cv.Send(NewValue(2)) + cv.Send(ValueOf(2)) if i := <-c; i != 2 { t.Errorf("reflect Send 2, native recv %d", i) } @@ -948,14 +964,14 @@ func TestChan(t *testing.T) { // TrySend fail c <- 100 - ok = cv.TrySend(NewValue(5)) + ok = cv.TrySend(ValueOf(5)) i := <-c if ok { t.Errorf("TrySend on full chan succeeded: value %d", i) } // TrySend success - ok = cv.TrySend(NewValue(6)) + ok = cv.TrySend(ValueOf(6)) if !ok { t.Errorf("TrySend on empty chan failed") } else { @@ -977,17 +993,17 @@ func TestChan(t *testing.T) { // check creation of unbuffered channel var c chan int - cv := MakeChan(Typeof(c), 0) + cv := MakeChan(TypeOf(c), 0) c = cv.Interface().(chan int) - if cv.TrySend(NewValue(7)) { + if cv.TrySend(ValueOf(7)) { t.Errorf("TrySend on sync chan succeeded") } if v, ok := cv.TryRecv(); v.IsValid() || ok { - t.Errorf("TryRecv on sync chan succeeded") + t.Errorf("TryRecv on sync chan succeeded: isvalid=%v ok=%v", v.IsValid(), ok) } // len/cap - cv = MakeChan(Typeof(c), 10) + cv = MakeChan(TypeOf(c), 10) c = cv.Interface().(chan int) for i := 0; i < 3; i++ { c <- i @@ -1005,7 +1021,7 @@ func dummy(b byte, c int, d byte) (i byte, j int, k byte) { } func TestFunc(t *testing.T) { - ret := NewValue(dummy).Call([]Value{NewValue(byte(10)), NewValue(20), NewValue(byte(30))}) + ret := ValueOf(dummy).Call([]Value{ValueOf(byte(10)), ValueOf(20), ValueOf(byte(30))}) if len(ret) != 3 { t.Fatalf("Call returned %d values, want 3", len(ret)) } @@ -1022,50 +1038,47 @@ type Point struct { x, y int } -func (p Point) Dist(scale int) int { return p.x*p.x*scale + p.y*p.y*scale } +func (p Point) Dist(scale int) int { + // println("Point.Dist", p.x, p.y, scale) + return p.x*p.x*scale + p.y*p.y*scale +} func TestMethod(t *testing.T) { // Non-curried method of type. p := Point{3, 4} - i := Typeof(p).Method(0).Func.Call([]Value{NewValue(p), NewValue(10)})[0].Int() + i := TypeOf(p).Method(0).Func.Call([]Value{ValueOf(p), ValueOf(10)})[0].Int() if i != 250 { t.Errorf("Type Method returned %d; want 250", i) } - i = Typeof(&p).Method(0).Func.Call([]Value{NewValue(&p), NewValue(10)})[0].Int() + i = TypeOf(&p).Method(0).Func.Call([]Value{ValueOf(&p), ValueOf(10)})[0].Int() if i != 250 { t.Errorf("Pointer Type Method returned %d; want 250", i) } // Curried method of value. - i = NewValue(p).Method(0).Call([]Value{NewValue(10)})[0].Int() + i = ValueOf(p).Method(0).Call([]Value{ValueOf(10)})[0].Int() if i != 250 { t.Errorf("Value Method returned %d; want 250", i) } // Curried method of pointer. - i = NewValue(&p).Method(0).Call([]Value{NewValue(10)})[0].Int() - if i != 250 { - t.Errorf("Value Method returned %d; want 250", i) - } - - // Curried method of pointer to value. - i = NewValue(p).Addr().Method(0).Call([]Value{NewValue(10)})[0].Int() + i = ValueOf(&p).Method(0).Call([]Value{ValueOf(10)})[0].Int() if i != 250 { t.Errorf("Value Method returned %d; want 250", i) } // Curried method of interface value. // Have to wrap interface value in a struct to get at it. - // Passing it to NewValue directly would + // Passing it to ValueOf directly would // access the underlying Point, not the interface. var s = struct { - x interface { + X interface { Dist(int) int } }{p} - pv := NewValue(s).Field(0) - i = pv.Method(0).Call([]Value{NewValue(10)})[0].Int() + pv := ValueOf(s).Field(0) + i = pv.Method(0).Call([]Value{ValueOf(10)})[0].Int() if i != 250 { t.Errorf("Interface Method returned %d; want 250", i) } @@ -1080,19 +1093,19 @@ func TestInterfaceSet(t *testing.T) { Dist(int) int } } - sv := NewValue(&s).Elem() - sv.Field(0).Set(NewValue(p)) + sv := ValueOf(&s).Elem() + sv.Field(0).Set(ValueOf(p)) if q := s.I.(*Point); q != p { t.Errorf("i: have %p want %p", q, p) } pv := sv.Field(1) - pv.Set(NewValue(p)) + pv.Set(ValueOf(p)) if q := s.P.(*Point); q != p { t.Errorf("i: have %p want %p", q, p) } - i := pv.Method(0).Call([]Value{NewValue(10)})[0].Int() + i := pv.Method(0).Call([]Value{ValueOf(10)})[0].Int() if i != 250 { t.Errorf("Interface Method returned %d; want 250", i) } @@ -1107,7 +1120,7 @@ func TestAnonymousFields(t *testing.T) { var field StructField var ok bool var t1 T1 - type1 := Typeof(t1) + type1 := TypeOf(t1) if field, ok = type1.FieldByName("int"); !ok { t.Error("no field 'int'") } @@ -1191,7 +1204,7 @@ var fieldTests = []FTest{ func TestFieldByIndex(t *testing.T) { for _, test := range fieldTests { - s := Typeof(test.s) + s := TypeOf(test.s) f := s.FieldByIndex(test.index) if f.Name != "" { if test.index != nil { @@ -1206,7 +1219,7 @@ func TestFieldByIndex(t *testing.T) { } if test.value != 0 { - v := NewValue(test.s).FieldByIndex(test.index) + v := ValueOf(test.s).FieldByIndex(test.index) if v.IsValid() { if x, ok := v.Interface().(int); ok { if x != test.value { @@ -1224,7 +1237,7 @@ func TestFieldByIndex(t *testing.T) { func TestFieldByName(t *testing.T) { for _, test := range fieldTests { - s := Typeof(test.s) + s := TypeOf(test.s) f, found := s.FieldByName(test.name) if found { if test.index != nil { @@ -1246,7 +1259,7 @@ func TestFieldByName(t *testing.T) { } if test.value != 0 { - v := NewValue(test.s).FieldByName(test.name) + v := ValueOf(test.s).FieldByName(test.name) if v.IsValid() { if x, ok := v.Interface().(int); ok { if x != test.value { @@ -1263,19 +1276,19 @@ func TestFieldByName(t *testing.T) { } func TestImportPath(t *testing.T) { - if path := Typeof(vector.Vector{}).PkgPath(); path != "container/vector" { - t.Errorf("Typeof(vector.Vector{}).PkgPath() = %q, want \"container/vector\"", path) + if path := TypeOf(vector.Vector{}).PkgPath(); path != "container/vector" { + t.Errorf("TypeOf(vector.Vector{}).PkgPath() = %q, want \"container/vector\"", path) } } func TestDotDotDot(t *testing.T) { // Test example from FuncType.DotDotDot documentation. var f func(x int, y ...float64) - typ := Typeof(f) - if typ.NumIn() == 2 && typ.In(0) == Typeof(int(0)) { + typ := TypeOf(f) + if typ.NumIn() == 2 && typ.In(0) == TypeOf(int(0)) { sl := typ.In(1) if sl.Kind() == Slice { - if sl.Elem() == Typeof(0.0) { + if sl.Elem() == TypeOf(0.0) { // ok return } @@ -1304,8 +1317,8 @@ func (*inner) m() {} func (*outer) m() {} func TestNestedMethods(t *testing.T) { - typ := Typeof((*outer)(nil)) - if typ.NumMethod() != 1 || typ.Method(0).Func.Pointer() != NewValue((*outer).m).Pointer() { + typ := TypeOf((*outer)(nil)) + if typ.NumMethod() != 1 || typ.Method(0).Func.Pointer() != ValueOf((*outer).m).Pointer() { t.Errorf("Wrong method table for outer: (m=%p)", (*outer).m) for i := 0; i < typ.NumMethod(); i++ { m := typ.Method(i) @@ -1314,40 +1327,40 @@ func TestNestedMethods(t *testing.T) { } } -type innerInt struct { - x int +type InnerInt struct { + X int } -type outerInt struct { - y int - innerInt +type OuterInt struct { + Y int + InnerInt } -func (i *innerInt) m() int { - return i.x +func (i *InnerInt) M() int { + return i.X } func TestEmbeddedMethods(t *testing.T) { - typ := Typeof((*outerInt)(nil)) - if typ.NumMethod() != 1 || typ.Method(0).Func.Pointer() != NewValue((*outerInt).m).Pointer() { - t.Errorf("Wrong method table for outerInt: (m=%p)", (*outerInt).m) + typ := TypeOf((*OuterInt)(nil)) + if typ.NumMethod() != 1 || typ.Method(0).Func.Pointer() != ValueOf((*OuterInt).M).Pointer() { + t.Errorf("Wrong method table for OuterInt: (m=%p)", (*OuterInt).M) for i := 0; i < typ.NumMethod(); i++ { m := typ.Method(i) t.Errorf("\t%d: %s %#x\n", i, m.Name, m.Func.Pointer()) } } - i := &innerInt{3} - if v := NewValue(i).Method(0).Call(nil)[0].Int(); v != 3 { - t.Errorf("i.m() = %d, want 3", v) + i := &InnerInt{3} + if v := ValueOf(i).Method(0).Call(nil)[0].Int(); v != 3 { + t.Errorf("i.M() = %d, want 3", v) } - o := &outerInt{1, innerInt{2}} - if v := NewValue(o).Method(0).Call(nil)[0].Int(); v != 2 { - t.Errorf("i.m() = %d, want 2", v) + o := &OuterInt{1, InnerInt{2}} + if v := ValueOf(o).Method(0).Call(nil)[0].Int(); v != 2 { + t.Errorf("i.M() = %d, want 2", v) } - f := (*outerInt).m + f := (*OuterInt).M if v := f(o); v != 2 { t.Errorf("f(o) = %d, want 2", v) } @@ -1356,15 +1369,15 @@ func TestEmbeddedMethods(t *testing.T) { func TestPtrTo(t *testing.T) { var i int - typ := Typeof(i) + typ := TypeOf(i) for i = 0; i < 100; i++ { typ = PtrTo(typ) } for i = 0; i < 100; i++ { typ = typ.Elem() } - if typ != Typeof(i) { - t.Errorf("after 100 PtrTo and Elem, have %s, want %s", typ, Typeof(i)) + if typ != TypeOf(i) { + t.Errorf("after 100 PtrTo and Elem, have %s, want %s", typ, TypeOf(i)) } } @@ -1373,7 +1386,7 @@ func TestAddr(t *testing.T) { X, Y int } - v := NewValue(&p) + v := ValueOf(&p) v = v.Elem() v = v.Addr() v = v.Elem() @@ -1383,9 +1396,10 @@ func TestAddr(t *testing.T) { t.Errorf("Addr.Elem.Set failed to set value") } - // Again but take address of the NewValue value. + // Again but take address of the ValueOf value. // Exercises generation of PtrTypes not present in the binary. - v = NewValue(&p) + q := &p + v = ValueOf(&q).Elem() v = v.Addr() v = v.Elem() v = v.Elem() @@ -1399,7 +1413,8 @@ func TestAddr(t *testing.T) { // Starting without pointer we should get changed value // in interface. - v = NewValue(p) + qq := p + v = ValueOf(&qq).Elem() v0 := v v = v.Addr() v = v.Elem() @@ -1415,3 +1430,67 @@ func TestAddr(t *testing.T) { t.Errorf("Addr.Elem.Set valued to set value in top value") } } + +func noAlloc(t *testing.T, n int, f func(int)) { + // once to prime everything + f(-1) + runtime.MemStats.Mallocs = 0 + + for j := 0; j < n; j++ { + f(j) + } + if runtime.MemStats.Mallocs != 0 { + t.Fatalf("%d mallocs after %d iterations", runtime.MemStats.Mallocs, n) + } +} + +func TestAllocations(t *testing.T) { + noAlloc(t, 100, func(j int) { + var i interface{} + var v Value + i = 42 + j + v = ValueOf(i) + if int(v.Int()) != 42+j { + panic("wrong int") + } + }) +} + +func TestSmallNegativeInt(t *testing.T) { + i := int16(-1) + v := ValueOf(i) + if v.Int() != -1 { + t.Errorf("int16(-1).Int() returned %v", v.Int()) + } +} + +func TestSlice(t *testing.T) { + xs := []int{1, 2, 3, 4, 5, 6, 7, 8} + v := ValueOf(xs).Slice(3, 5).Interface().([]int) + if len(v) != 2 || v[0] != 4 || v[1] != 5 { + t.Errorf("xs.Slice(3, 5) = %v", v) + } + + xa := [7]int{10, 20, 30, 40, 50, 60, 70} + v = ValueOf(&xa).Elem().Slice(2, 5).Interface().([]int) + if len(v) != 3 || v[0] != 30 || v[1] != 40 || v[2] != 50 { + t.Errorf("xa.Slice(2, 5) = %v", v) + } +} + +func TestVariadic(t *testing.T) { + var b bytes.Buffer + V := ValueOf + + b.Reset() + V(fmt.Fprintf).Call([]Value{V(&b), V("%s, %d world"), V("hello"), V(42)}) + if b.String() != "hello, 42 world" { + t.Errorf("after Fprintf Call: %q != %q", b.String(), "hello 42 world") + } + + b.Reset() + V(fmt.Fprintf).CallSlice([]Value{V(&b), V("%s, %d world"), V([]interface{}{"hello", 42})}) + if b.String() != "hello, 42 world" { + t.Errorf("after Fprintf CallSlice: %q != %q", b.String(), "hello 42 world") + } +} diff --git a/src/pkg/reflect/deepequal.go b/src/pkg/reflect/deepequal.go index f5a781460..a483135b0 100644 --- a/src/pkg/reflect/deepequal.go +++ b/src/pkg/reflect/deepequal.go @@ -6,7 +6,6 @@ package reflect - // During deepValueEqual, must keep track of checks that are // in progress. The comparison algorithm assumes that all // checks in progress are true when it reencounters them. @@ -21,7 +20,7 @@ type visit struct { // Tests for deep equality using reflected types. The map argument tracks // comparisons that have already been seen, which allows short circuiting on // recursive types. -func deepValueEqual(v1, v2 Value, visited map[uintptr]*visit, depth int) bool { +func deepValueEqual(v1, v2 Value, visited map[uintptr]*visit, depth int) (b bool) { if !v1.IsValid() || !v2.IsValid() { return v1.IsValid() == v2.IsValid() } @@ -31,30 +30,32 @@ func deepValueEqual(v1, v2 Value, visited map[uintptr]*visit, depth int) bool { // if depth > 10 { panic("deepValueEqual") } // for debugging - addr1 := v1.UnsafeAddr() - addr2 := v2.UnsafeAddr() - if addr1 > addr2 { - // Canonicalize order to reduce number of entries in visited. - addr1, addr2 = addr2, addr1 - } - - // Short circuit if references are identical ... - if addr1 == addr2 { - return true - } + if v1.CanAddr() && v2.CanAddr() { + addr1 := v1.UnsafeAddr() + addr2 := v2.UnsafeAddr() + if addr1 > addr2 { + // Canonicalize order to reduce number of entries in visited. + addr1, addr2 = addr2, addr1 + } - // ... or already seen - h := 17*addr1 + addr2 - seen := visited[h] - typ := v1.Type() - for p := seen; p != nil; p = p.next { - if p.a1 == addr1 && p.a2 == addr2 && p.typ == typ { + // Short circuit if references are identical ... + if addr1 == addr2 { return true } - } - // Remember for later. - visited[h] = &visit{addr1, addr2, typ, seen} + // ... or already seen + h := 17*addr1 + addr2 + seen := visited[h] + typ := v1.Type() + for p := seen; p != nil; p = p.next { + if p.a1 == addr1 && p.a2 == addr2 && p.typ == typ { + return true + } + } + + // Remember for later. + visited[h] = &visit{addr1, addr2, typ, seen} + } switch v1.Kind() { case Array: @@ -116,8 +117,8 @@ func DeepEqual(a1, a2 interface{}) bool { if a1 == nil || a2 == nil { return a1 == a2 } - v1 := NewValue(a1) - v2 := NewValue(a2) + v1 := ValueOf(a1) + v2 := ValueOf(a2) if v1.Type() != v2.Type() { return false } diff --git a/src/pkg/reflect/set_test.go b/src/pkg/reflect/set_test.go new file mode 100644 index 000000000..8135a4cd1 --- /dev/null +++ b/src/pkg/reflect/set_test.go @@ -0,0 +1,211 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package reflect_test + +import ( + "bytes" + "go/ast" + "io" + . "reflect" + "testing" + "unsafe" +) + +type MyBuffer bytes.Buffer + +func TestImplicitMapConversion(t *testing.T) { + // Test implicit conversions in MapIndex and SetMapIndex. + { + // direct + m := make(map[int]int) + mv := ValueOf(m) + mv.SetMapIndex(ValueOf(1), ValueOf(2)) + x, ok := m[1] + if x != 2 { + t.Errorf("#1 after SetMapIndex(1,2): %d, %t (map=%v)", x, ok, m) + } + if n := mv.MapIndex(ValueOf(1)).Interface().(int); n != 2 { + t.Errorf("#1 MapIndex(1) = %d", n) + } + } + { + // convert interface key + m := make(map[interface{}]int) + mv := ValueOf(m) + mv.SetMapIndex(ValueOf(1), ValueOf(2)) + x, ok := m[1] + if x != 2 { + t.Errorf("#2 after SetMapIndex(1,2): %d, %t (map=%v)", x, ok, m) + } + if n := mv.MapIndex(ValueOf(1)).Interface().(int); n != 2 { + t.Errorf("#2 MapIndex(1) = %d", n) + } + } + { + // convert interface value + m := make(map[int]interface{}) + mv := ValueOf(m) + mv.SetMapIndex(ValueOf(1), ValueOf(2)) + x, ok := m[1] + if x != 2 { + t.Errorf("#3 after SetMapIndex(1,2): %d, %t (map=%v)", x, ok, m) + } + if n := mv.MapIndex(ValueOf(1)).Interface().(int); n != 2 { + t.Errorf("#3 MapIndex(1) = %d", n) + } + } + { + // convert both interface key and interface value + m := make(map[interface{}]interface{}) + mv := ValueOf(m) + mv.SetMapIndex(ValueOf(1), ValueOf(2)) + x, ok := m[1] + if x != 2 { + t.Errorf("#4 after SetMapIndex(1,2): %d, %t (map=%v)", x, ok, m) + } + if n := mv.MapIndex(ValueOf(1)).Interface().(int); n != 2 { + t.Errorf("#4 MapIndex(1) = %d", n) + } + } + { + // convert both, with non-empty interfaces + m := make(map[io.Reader]io.Writer) + mv := ValueOf(m) + b1 := new(bytes.Buffer) + b2 := new(bytes.Buffer) + mv.SetMapIndex(ValueOf(b1), ValueOf(b2)) + x, ok := m[b1] + if x != b2 { + t.Errorf("#5 after SetMapIndex(b1, b2): %p (!= %p), %t (map=%v)", x, b2, ok, m) + } + if p := mv.MapIndex(ValueOf(b1)).Elem().Pointer(); p != uintptr(unsafe.Pointer(b2)) { + t.Errorf("#5 MapIndex(b1) = %p want %p", p, b2) + } + } + { + // convert channel direction + m := make(map[<-chan int]chan int) + mv := ValueOf(m) + c1 := make(chan int) + c2 := make(chan int) + mv.SetMapIndex(ValueOf(c1), ValueOf(c2)) + x, ok := m[c1] + if x != c2 { + t.Errorf("#6 after SetMapIndex(c1, c2): %p (!= %p), %t (map=%v)", x, c2, ok, m) + } + if p := mv.MapIndex(ValueOf(c1)).Pointer(); p != ValueOf(c2).Pointer() { + t.Errorf("#6 MapIndex(c1) = %p want %p", p, c2) + } + } + { + // convert identical underlying types + // TODO(rsc): Should be able to define MyBuffer here. + // 6l prints very strange messages about .this.Bytes etc + // when we do that though, so MyBuffer is defined + // at top level. + m := make(map[*MyBuffer]*bytes.Buffer) + mv := ValueOf(m) + b1 := new(MyBuffer) + b2 := new(bytes.Buffer) + mv.SetMapIndex(ValueOf(b1), ValueOf(b2)) + x, ok := m[b1] + if x != b2 { + t.Errorf("#7 after SetMapIndex(b1, b2): %p (!= %p), %t (map=%v)", x, b2, ok, m) + } + if p := mv.MapIndex(ValueOf(b1)).Pointer(); p != uintptr(unsafe.Pointer(b2)) { + t.Errorf("#7 MapIndex(b1) = %p want %p", p, b2) + } + } + +} + +func TestImplicitSetConversion(t *testing.T) { + // Assume TestImplicitMapConversion covered the basics. + // Just make sure conversions are being applied at all. + var r io.Reader + b := new(bytes.Buffer) + rv := ValueOf(&r).Elem() + rv.Set(ValueOf(b)) + if r != b { + t.Errorf("after Set: r=%T(%v)", r, r) + } +} + +func TestImplicitSendConversion(t *testing.T) { + c := make(chan io.Reader, 10) + b := new(bytes.Buffer) + ValueOf(c).Send(ValueOf(b)) + if bb := <-c; bb != b { + t.Errorf("Received %p != %p", bb, b) + } +} + +func TestImplicitCallConversion(t *testing.T) { + // Arguments must be assignable to parameter types. + fv := ValueOf(io.WriteString) + b := new(bytes.Buffer) + fv.Call([]Value{ValueOf(b), ValueOf("hello world")}) + if b.String() != "hello world" { + t.Errorf("After call: string=%q want %q", b.String(), "hello world") + } +} + +func TestImplicitAppendConversion(t *testing.T) { + // Arguments must be assignable to the slice's element type. + s := []io.Reader{} + sv := ValueOf(&s).Elem() + b := new(bytes.Buffer) + sv.Set(Append(sv, ValueOf(b))) + if len(s) != 1 || s[0] != b { + t.Errorf("after append: s=%v want [%p]", s, b) + } +} + +var implementsTests = []struct { + x interface{} + t interface{} + b bool +}{ + {new(*bytes.Buffer), new(io.Reader), true}, + {new(bytes.Buffer), new(io.Reader), false}, + {new(*bytes.Buffer), new(io.ReaderAt), false}, + {new(*ast.Ident), new(ast.Expr), true}, +} + +func TestImplements(t *testing.T) { + for _, tt := range implementsTests { + xv := TypeOf(tt.x).Elem() + xt := TypeOf(tt.t).Elem() + if b := xv.Implements(xt); b != tt.b { + t.Errorf("(%s).Implements(%s) = %v, want %v", xv.String(), xt.String(), b, tt.b) + } + } +} + +var assignableTests = []struct { + x interface{} + t interface{} + b bool +}{ + {new(chan int), new(<-chan int), true}, + {new(<-chan int), new(chan int), false}, + {new(*int), new(IntPtr), true}, + {new(IntPtr), new(*int), true}, + {new(IntPtr), new(IntPtr1), false}, + // test runs implementsTests too +} + +type IntPtr *int +type IntPtr1 *int + +func TestAssignableTo(t *testing.T) { + for _, tt := range append(assignableTests, implementsTests...) { + xv := TypeOf(tt.x).Elem() + xt := TypeOf(tt.t).Elem() + if b := xv.AssignableTo(xt); b != tt.b { + t.Errorf("(%s).AssignableTo(%s) = %v, want %v", xv.String(), xt.String(), b, tt.b) + } + } +} diff --git a/src/pkg/reflect/type.go b/src/pkg/reflect/type.go index 9f3e0bf68..aef6370db 100644 --- a/src/pkg/reflect/type.go +++ b/src/pkg/reflect/type.go @@ -2,12 +2,12 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The reflect package implements run-time reflection, allowing a program to -// manipulate objects with arbitrary types. The typical use is to take a -// value with static type interface{} and extract its dynamic type -// information by calling Typeof, which returns a Type. +// Package reflect implements run-time reflection, allowing a program to +// manipulate objects with arbitrary types. The typical use is to take a value +// with static type interface{} and extract its dynamic type information by +// calling TypeOf, which returns a Type. // -// A call to NewValue returns a Value representing the run-time data. +// A call to ValueOf returns a Value representing the run-time data. // Zero takes a Type and returns a Value representing a zero value // for that type. package reflect @@ -47,7 +47,7 @@ type Type interface { // method signature, without a receiver, and the Func field is nil. Method(int) Method - // NumMethods returns the number of methods in the type's method set. + // NumMethod returns the number of methods in the type's method set. NumMethod() int // Name returns the type's name within its package. @@ -73,6 +73,12 @@ type Type interface { // Kind returns the specific kind of this type. Kind() Kind + // Implements returns true if the type implements the interface type u. + Implements(u Type) bool + + // AssignableTo returns true if a value of the type is assignable to type u. + AssignableTo(u Type) bool + // Methods applicable only to some types, depending on Kind. // The methods allowed for each kind are: // @@ -162,6 +168,8 @@ type Type interface { // It panics if i is not in the range [0, NumOut()). Out(i int) Type + runtimeType() *runtime.Type + common() *commonType uncommon() *uncommonType } @@ -258,6 +266,7 @@ const ( type arrayType struct { commonType "array" elem *runtime.Type + slice *runtime.Type len uintptr } @@ -408,9 +417,12 @@ func (t *commonType) String() string { return *t.string } func (t *commonType) Size() uintptr { return t.size } func (t *commonType) Bits() int { + if t == nil { + panic("reflect: Bits of nil Type") + } k := t.Kind() if k < Int || k > Complex128 { - panic("reflect: Bits of non-arithmetic Type") + panic("reflect: Bits of non-arithmetic Type " + t.String()) } return int(t.size) * 8 } @@ -431,12 +443,14 @@ func (t *uncommonType) Method(i int) (m Method) { if p.name != nil { m.Name = *p.name } + flag := uint32(0) if p.pkgPath != nil { m.PkgPath = *p.pkgPath + flag |= flagRO } m.Type = toType(p.typ) fn := p.tfn - m.Func = Value{&funcValue{value: value{m.Type, addr(&fn), canSet}}} + m.Func = valueFromIword(flag, m.Type, iword(fn)) return } @@ -772,24 +786,32 @@ func (t *structType) FieldByNameFunc(match func(string) bool) (f StructField, pr } // Convert runtime type to reflect type. -func toType(p *runtime.Type) Type { +func toCommonType(p *runtime.Type) *commonType { + if p == nil { + return nil + } type hdr struct { x interface{} t commonType } - t := &(*hdr)(unsafe.Pointer(p)).t - return t.toType() + x := unsafe.Pointer(p) + if uintptr(x)&reflectFlags != 0 { + panic("invalid interface value") + } + return &(*hdr)(x).t } -// Typeof returns the reflection Type of the value in the interface{}. -func Typeof(i interface{}) Type { - type hdr struct { - typ *byte - val *commonType +func toType(p *runtime.Type) Type { + if p == nil { + return nil } - rt := unsafe.Typeof(i) - t := (*(*hdr)(unsafe.Pointer(&rt))).val - return t.toType() + return toCommonType(p).toType() +} + +// TypeOf returns the reflection Type of the value in the interface{}. +func TypeOf(i interface{}) Type { + eface := *(*emptyInterface)(unsafe.Pointer(&i)) + return toType(eface.typ) } // ptrMap is the cache for PtrTo. @@ -798,6 +820,16 @@ var ptrMap struct { m map[*commonType]*ptrType } +func (t *commonType) runtimeType() *runtime.Type { + // The runtime.Type always precedes the commonType in memory. + // Adjust pointer to find it. + var rt struct { + i runtime.Type + ct commonType + } + return (*runtime.Type)(unsafe.Pointer(uintptr(unsafe.Pointer(t)) - uintptr(unsafe.Offsetof(rt.ct)))) +} + // PtrTo returns the pointer type with element t. // For example, if t represents type Foo, PtrTo(t) represents *Foo. func PtrTo(t Type) Type { @@ -862,3 +894,164 @@ func PtrTo(t Type) Type { ptrMap.Unlock() return p.commonType.toType() } + +func (t *commonType) Implements(u Type) bool { + if u == nil { + panic("reflect: nil type passed to Type.Implements") + } + if u.Kind() != Interface { + panic("reflect: non-interface type passed to Type.Implements") + } + return implements(u.(*commonType), t) +} + +func (t *commonType) AssignableTo(u Type) bool { + if u == nil { + panic("reflect: nil type passed to Type.AssignableTo") + } + uu := u.(*commonType) + return directlyAssignable(uu, t) || implements(uu, t) +} + +// implements returns true if the type V implements the interface type T. +func implements(T, V *commonType) bool { + if T.Kind() != Interface { + return false + } + t := (*interfaceType)(unsafe.Pointer(T)) + if len(t.methods) == 0 { + return true + } + + // The same algorithm applies in both cases, but the + // method tables for an interface type and a concrete type + // are different, so the code is duplicated. + // In both cases the algorithm is a linear scan over the two + // lists - T's methods and V's methods - simultaneously. + // Since method tables are stored in a unique sorted order + // (alphabetical, with no duplicate method names), the scan + // through V's methods must hit a match for each of T's + // methods along the way, or else V does not implement T. + // This lets us run the scan in overall linear time instead of + // the quadratic time a naive search would require. + // See also ../runtime/iface.c. + if V.Kind() == Interface { + v := (*interfaceType)(unsafe.Pointer(V)) + i := 0 + for j := 0; j < len(v.methods); j++ { + tm := &t.methods[i] + vm := &v.methods[j] + if vm.name == tm.name && vm.pkgPath == tm.pkgPath && vm.typ == tm.typ { + if i++; i >= len(t.methods) { + return true + } + } + } + return false + } + + v := V.uncommon() + if v == nil { + return false + } + i := 0 + for j := 0; j < len(v.methods); j++ { + tm := &t.methods[i] + vm := &v.methods[j] + if vm.name == tm.name && vm.pkgPath == tm.pkgPath && vm.mtyp == tm.typ { + if i++; i >= len(t.methods) { + return true + } + } + } + return false +} + +// directlyAssignable returns true if a value x of type V can be directly +// assigned (using memmove) to a value of type T. +// http://golang.org/doc/go_spec.html#Assignability +// Ignoring the interface rules (implemented elsewhere) +// and the ideal constant rules (no ideal constants at run time). +func directlyAssignable(T, V *commonType) bool { + // x's type V is identical to T? + if T == V { + return true + } + + // Otherwise at least one of T and V must be unnamed + // and they must have the same kind. + if T.Name() != "" && V.Name() != "" || T.Kind() != V.Kind() { + return false + } + + // x's type T and V have identical underlying types. + // Since at least one is unnamed, only the composite types + // need to be considered. + switch T.Kind() { + case Array: + return T.Elem() == V.Elem() && T.Len() == V.Len() + + case Chan: + // Special case: + // x is a bidirectional channel value, T is a channel type, + // and x's type V and T have identical element types. + if V.ChanDir() == BothDir && T.Elem() == V.Elem() { + return true + } + + // Otherwise continue test for identical underlying type. + return V.ChanDir() == T.ChanDir() && T.Elem() == V.Elem() + + case Func: + t := (*funcType)(unsafe.Pointer(T)) + v := (*funcType)(unsafe.Pointer(V)) + if t.dotdotdot != v.dotdotdot || len(t.in) != len(v.in) || len(t.out) != len(v.out) { + return false + } + for i, typ := range t.in { + if typ != v.in[i] { + return false + } + } + for i, typ := range t.out { + if typ != v.out[i] { + return false + } + } + return true + + case Interface: + t := (*interfaceType)(unsafe.Pointer(T)) + v := (*interfaceType)(unsafe.Pointer(V)) + if len(t.methods) == 0 && len(v.methods) == 0 { + return true + } + // Might have the same methods but still + // need a run time conversion. + return false + + case Map: + return T.Key() == V.Key() && T.Elem() == V.Elem() + + case Ptr, Slice: + return T.Elem() == V.Elem() + + case Struct: + t := (*structType)(unsafe.Pointer(T)) + v := (*structType)(unsafe.Pointer(V)) + if len(t.fields) != len(v.fields) { + return false + } + for i := range t.fields { + tf := &t.fields[i] + vf := &v.fields[i] + if tf.name != vf.name || tf.pkgPath != vf.pkgPath || + tf.typ != vf.typ || tf.tag != vf.tag || tf.offset != vf.offset { + return false + } + } + return true + } + + return false +} diff --git a/src/pkg/reflect/value.go b/src/pkg/reflect/value.go index ddc31100f..6dffb0783 100644 --- a/src/pkg/reflect/value.go +++ b/src/pkg/reflect/value.go @@ -7,17 +7,16 @@ package reflect import ( "math" "runtime" + "strconv" "unsafe" ) const ptrSize = uintptr(unsafe.Sizeof((*byte)(nil))) const cannotSet = "cannot set value obtained from unexported struct field" -type addr unsafe.Pointer - // TODO: This will have to go away when // the new gc goes in. -func memmove(adst, asrc addr, n uintptr) { +func memmove(adst, asrc unsafe.Pointer, n uintptr) { dst := uintptr(adst) src := uintptr(asrc) switch { @@ -26,17 +25,17 @@ func memmove(adst, asrc addr, n uintptr) { // careful: i is unsigned for i := n; i > 0; { i-- - *(*byte)(addr(dst + i)) = *(*byte)(addr(src + i)) + *(*byte)(unsafe.Pointer(dst + i)) = *(*byte)(unsafe.Pointer(src + i)) } case (n|src|dst)&(ptrSize-1) != 0: // byte copy forward for i := uintptr(0); i < n; i++ { - *(*byte)(addr(dst + i)) = *(*byte)(addr(src + i)) + *(*byte)(unsafe.Pointer(dst + i)) = *(*byte)(unsafe.Pointer(src + i)) } default: // word copy forward for i := uintptr(0); i < n; i += ptrSize { - *(*uintptr)(addr(dst + i)) = *(*uintptr)(addr(src + i)) + *(*uintptr)(unsafe.Pointer(dst + i)) = *(*uintptr)(unsafe.Pointer(src + i)) } } } @@ -54,15 +53,16 @@ func memmove(adst, asrc addr, n uintptr) { // its String method returns "<invalid Value>", and all other methods panic. // Most functions and methods never return an invalid value. // If one does, its documentation states the conditions explicitly. +// +// The fields of Value are exported so that clients can copy and +// pass Values around, but they should not be edited or inspected +// directly. A future language change may make it possible not to +// export these fields while still keeping Values usable as values. type Value struct { - Internal valueInterface + Internal interface{} + InternalMethod int } -// TODO(rsc): This implementation of Value is a just a façade -// in front of the old implementation, now called valueInterface. -// A future CL will change it to a real implementation. -// Changing the API is already a big enough step for one CL. - // A ValueError occurs when a Value method is invoked on // a Value that does not support it. Such cases are documented // in the description of each method. @@ -89,37 +89,292 @@ func methodName() string { return f.Name() } -func (v Value) internal() valueInterface { - vi := v.Internal - if vi == nil { - panic(&ValueError{methodName(), 0}) +// An iword is the word that would be stored in an +// interface to represent a given value v. Specifically, if v is +// bigger than a pointer, its word is a pointer to v's data. +// Otherwise, its word is a zero uintptr with the data stored +// in the leading bytes. +type iword uintptr + +func loadIword(p unsafe.Pointer, size uintptr) iword { + // Run the copy ourselves instead of calling memmove + // to avoid moving v to the heap. + w := iword(0) + switch size { + default: + panic("reflect: internal error: loadIword of " + strconv.Itoa(int(size)) + "-byte value") + case 0: + case 1: + *(*uint8)(unsafe.Pointer(&w)) = *(*uint8)(p) + case 2: + *(*uint16)(unsafe.Pointer(&w)) = *(*uint16)(p) + case 3: + *(*[3]byte)(unsafe.Pointer(&w)) = *(*[3]byte)(p) + case 4: + *(*uint32)(unsafe.Pointer(&w)) = *(*uint32)(p) + case 5: + *(*[5]byte)(unsafe.Pointer(&w)) = *(*[5]byte)(p) + case 6: + *(*[6]byte)(unsafe.Pointer(&w)) = *(*[6]byte)(p) + case 7: + *(*[7]byte)(unsafe.Pointer(&w)) = *(*[7]byte)(p) + case 8: + *(*uint64)(unsafe.Pointer(&w)) = *(*uint64)(p) + } + return w +} + +func storeIword(p unsafe.Pointer, w iword, size uintptr) { + // Run the copy ourselves instead of calling memmove + // to avoid moving v to the heap. + switch size { + default: + panic("reflect: internal error: storeIword of " + strconv.Itoa(int(size)) + "-byte value") + case 0: + case 1: + *(*uint8)(p) = *(*uint8)(unsafe.Pointer(&w)) + case 2: + *(*uint16)(p) = *(*uint16)(unsafe.Pointer(&w)) + case 3: + *(*[3]byte)(p) = *(*[3]byte)(unsafe.Pointer(&w)) + case 4: + *(*uint32)(p) = *(*uint32)(unsafe.Pointer(&w)) + case 5: + *(*[5]byte)(p) = *(*[5]byte)(unsafe.Pointer(&w)) + case 6: + *(*[6]byte)(p) = *(*[6]byte)(unsafe.Pointer(&w)) + case 7: + *(*[7]byte)(p) = *(*[7]byte)(unsafe.Pointer(&w)) + case 8: + *(*uint64)(p) = *(*uint64)(unsafe.Pointer(&w)) + } +} + +// emptyInterface is the header for an interface{} value. +type emptyInterface struct { + typ *runtime.Type + word iword +} + +// nonEmptyInterface is the header for a interface value with methods. +type nonEmptyInterface struct { + // see ../runtime/iface.c:/Itab + itab *struct { + ityp *runtime.Type // static interface type + typ *runtime.Type // dynamic concrete type + link unsafe.Pointer + bad int32 + unused int32 + fun [100000]unsafe.Pointer // method table + } + word iword +} + +// Regarding the implementation of Value: +// +// The Internal interface is a true interface value in the Go sense, +// but it also serves as a (type, address) pair in whcih one cannot +// be changed separately from the other. That is, it serves as a way +// to prevent unsafe mutations of the Internal state even though +// we cannot (yet?) hide the field while preserving the ability for +// clients to make copies of Values. +// +// The internal method converts a Value into the expanded internalValue struct. +// If we could avoid exporting fields we'd probably make internalValue the +// definition of Value. +// +// If a Value is addressable (CanAddr returns true), then the Internal +// interface value holds a pointer to the actual field data, and Set stores +// through that pointer. If a Value is not addressable (CanAddr returns false), +// then the Internal interface value holds the actual value. +// +// In addition to whether a value is addressable, we track whether it was +// obtained by using an unexported struct field. Such values are allowed +// to be read, mainly to make fmt.Print more useful, but they are not +// allowed to be written. We call such values read-only. +// +// A Value can be set (via the Set, SetUint, etc. methods) only if it is both +// addressable and not read-only. +// +// The two permission bits - addressable and read-only - are stored in +// the bottom two bits of the type pointer in the interface value. +// +// ordinary value: Internal = value +// addressable value: Internal = value, Internal.typ |= flagAddr +// read-only value: Internal = value, Internal.typ |= flagRO +// addressable, read-only value: Internal = value, Internal.typ |= flagAddr | flagRO +// +// It is important that the read-only values have the extra bit set +// (as opposed to using the bit to mean writable), because client code +// can grab the interface field and try to use it. Having the extra bit +// set makes the type pointer compare not equal to any real type, +// so that a client cannot, say, write through v.Internal.(*int). +// The runtime routines that access interface types reject types with +// low bits set. +// +// If a Value fv = v.Method(i), then fv = v with the InternalMethod +// field set to i+1. Methods are never addressable. +// +// All in all, this is a lot of effort just to avoid making this new API +// depend on a language change we'll probably do anyway, but +// it's helpful to keep the two separate, and much of the logic is +// necessary to implement the Interface method anyway. + +const ( + flagAddr uint32 = 1 << iota // holds address of value + flagRO // read-only + + reflectFlags = 3 +) + +// An internalValue is the unpacked form of a Value. +// The zero Value unpacks to a zero internalValue +type internalValue struct { + typ *commonType // type of value + kind Kind // kind of value + flag uint32 + word iword + addr unsafe.Pointer + rcvr iword + method bool + nilmethod bool +} + +func (v Value) internal() internalValue { + var iv internalValue + eface := *(*emptyInterface)(unsafe.Pointer(&v.Internal)) + p := uintptr(unsafe.Pointer(eface.typ)) + iv.typ = toCommonType((*runtime.Type)(unsafe.Pointer(p &^ reflectFlags))) + if iv.typ == nil { + return iv + } + iv.flag = uint32(p & reflectFlags) + iv.word = eface.word + if iv.flag&flagAddr != 0 { + iv.addr = unsafe.Pointer(iv.word) + iv.typ = iv.typ.Elem().common() + if iv.typ.size <= ptrSize { + iv.word = loadIword(iv.addr, iv.typ.size) + } + } else { + if iv.typ.size > ptrSize { + iv.addr = unsafe.Pointer(iv.word) + } } - return vi + iv.kind = iv.typ.Kind() + + // Is this a method? If so, iv describes the receiver. + // Rewrite to describe the method function. + if v.InternalMethod != 0 { + // If this Value is a method value (x.Method(i) for some Value x) + // then we will invoke it using the interface form of the method, + // which always passes the receiver as a single word. + // Record that information. + i := v.InternalMethod - 1 + if iv.kind == Interface { + it := (*interfaceType)(unsafe.Pointer(iv.typ)) + if i < 0 || i >= len(it.methods) { + panic("reflect: broken Value") + } + m := &it.methods[i] + if m.pkgPath != nil { + iv.flag |= flagRO + } + iv.typ = toCommonType(m.typ) + iface := (*nonEmptyInterface)(iv.addr) + if iface.itab == nil { + iv.word = 0 + iv.nilmethod = true + } else { + iv.word = iword(iface.itab.fun[i]) + } + iv.rcvr = iface.word + } else { + ut := iv.typ.uncommon() + if ut == nil || i < 0 || i >= len(ut.methods) { + panic("reflect: broken Value") + } + m := &ut.methods[i] + if m.pkgPath != nil { + iv.flag |= flagRO + } + iv.typ = toCommonType(m.mtyp) + iv.rcvr = iv.word + iv.word = iword(m.ifn) + } + iv.kind = Func + iv.method = true + iv.flag &^= flagAddr + iv.addr = nil + } + + return iv } -func (v Value) panicIfNot(want Kind) valueInterface { - vi := v.Internal - if vi == nil { - panic(&ValueError{methodName(), 0}) +// packValue returns a Value with the given flag bits, type, and interface word. +func packValue(flag uint32, typ *runtime.Type, word iword) Value { + if typ == nil { + panic("packValue") } - if k := vi.Kind(); k != want { - panic(&ValueError{methodName(), k}) + t := uintptr(unsafe.Pointer(typ)) + t |= uintptr(flag) + eface := emptyInterface{(*runtime.Type)(unsafe.Pointer(t)), word} + return Value{Internal: *(*interface{})(unsafe.Pointer(&eface))} +} + +// valueFromAddr returns a Value using the given type and address. +func valueFromAddr(flag uint32, typ Type, addr unsafe.Pointer) Value { + if flag&flagAddr != 0 { + // Addressable, so the internal value is + // an interface containing a pointer to the real value. + return packValue(flag, PtrTo(typ).runtimeType(), iword(addr)) } - return vi + + var w iword + if n := typ.Size(); n <= ptrSize { + // In line, so the interface word is the actual value. + w = loadIword(addr, n) + } else { + // Not in line: the interface word is the address. + w = iword(addr) + } + return packValue(flag, typ.runtimeType(), w) } -func (v Value) panicIfNots(wants []Kind) valueInterface { - vi := v.Internal - if vi == nil { - panic(&ValueError{methodName(), 0}) +// valueFromIword returns a Value using the given type and interface word. +func valueFromIword(flag uint32, typ Type, w iword) Value { + if flag&flagAddr != 0 { + panic("reflect: internal error: valueFromIword addressable") } - k := vi.Kind() - for _, want := range wants { - if k == want { - return vi - } + return packValue(flag, typ.runtimeType(), w) +} + +func (iv internalValue) mustBe(want Kind) { + if iv.kind != want { + panic(&ValueError{methodName(), iv.kind}) + } +} + +func (iv internalValue) mustBeExported() { + if iv.kind == 0 { + panic(&ValueError{methodName(), iv.kind}) + } + if iv.flag&flagRO != 0 { + panic(methodName() + " using value obtained using unexported field") + } +} + +func (iv internalValue) mustBeAssignable() { + if iv.kind == 0 { + panic(&ValueError{methodName(), iv.kind}) + } + // Assignable if addressable and not read-only. + if iv.flag&flagRO != 0 { + panic(methodName() + " using value obtained using unexported field") + } + if iv.flag&flagAddr == 0 { + panic(methodName() + " using unaddressable value") } - panic(&ValueError{methodName(), k}) } // Addr returns a pointer value representing the address of v. @@ -128,56 +383,142 @@ func (v Value) panicIfNots(wants []Kind) valueInterface { // or slice element in order to call a method that requires a // pointer receiver. func (v Value) Addr() Value { - return v.internal().Addr() + iv := v.internal() + if iv.flag&flagAddr == 0 { + panic("reflect.Value.Addr of unaddressable value") + } + return valueFromIword(iv.flag&flagRO, PtrTo(iv.typ.toType()), iword(iv.addr)) } // Bool returns v's underlying value. // It panics if v's kind is not Bool. func (v Value) Bool() bool { - u := v.panicIfNot(Bool).(*boolValue) - return *(*bool)(u.addr) + iv := v.internal() + iv.mustBe(Bool) + return *(*bool)(unsafe.Pointer(&iv.word)) } // CanAddr returns true if the value's address can be obtained with Addr. // Such values are called addressable. A value is addressable if it is // an element of a slice, an element of an addressable array, -// a field of an addressable struct, the result of dereferencing a pointer, -// or the result of a call to NewValue, MakeChan, MakeMap, or Zero. +// a field of an addressable struct, or the result of dereferencing a pointer. // If CanAddr returns false, calling Addr will panic. func (v Value) CanAddr() bool { - return v.internal().CanAddr() + iv := v.internal() + return iv.flag&flagAddr != 0 } // CanSet returns true if the value of v can be changed. -// Values obtained by the use of unexported struct fields -// can be read but not set. +// A Value can be changed only if it is addressable and was not +// obtained by the use of unexported struct fields. // If CanSet returns false, calling Set or any type-specific // setter (e.g., SetBool, SetInt64) will panic. func (v Value) CanSet() bool { - return v.internal().CanSet() -} - -// Call calls the function v with the input parameters in. -// It panics if v's Kind is not Func. -// It returns the output parameters as Values. + iv := v.internal() + return iv.flag&(flagAddr|flagRO) == flagAddr +} + +// Call calls the function v with the input arguments in. +// For example, if len(in) == 3, v.Call(in) represents the Go call v(in[0], in[1], in[2]). +// Call panics if v's Kind is not Func. +// It returns the output results as Values. +// As in Go, each input argument must be assignable to the +// type of the function's corresponding input parameter. +// If v is a variadic function, Call creates the variadic slice parameter +// itself, copying in the corresponding values. func (v Value) Call(in []Value) []Value { - fv := v.panicIfNot(Func).(*funcValue) - t := fv.Type() - nin := len(in) - if fv.first != nil && !fv.isInterface { - nin++ + iv := v.internal() + iv.mustBe(Func) + iv.mustBeExported() + return iv.call("Call", in) +} + +// CallSlice calls the variadic function v with the input arguments in, +// assigning the slice in[len(in)-1] to v's final variadic argument. +// For example, if len(in) == 3, v.Call(in) represents the Go call v(in[0], in[1], in[2]...). +// Call panics if v's Kind is not Func or if v is not variadic. +// It returns the output results as Values. +// As in Go, each input argument must be assignable to the +// type of the function's corresponding input parameter. +func (v Value) CallSlice(in []Value) []Value { + iv := v.internal() + iv.mustBe(Func) + iv.mustBeExported() + return iv.call("CallSlice", in) +} + +func (iv internalValue) call(method string, in []Value) []Value { + if iv.word == 0 { + if iv.nilmethod { + panic("reflect.Value.Call: call of method on nil interface value") + } + panic("reflect.Value.Call: call of nil function") + } + + isSlice := method == "CallSlice" + t := iv.typ + n := t.NumIn() + if isSlice { + if !t.IsVariadic() { + panic("reflect: CallSlice of non-variadic function") + } + if len(in) < n { + panic("reflect: CallSlice with too few input arguments") + } + if len(in) > n { + panic("reflect: CallSlice with too many input arguments") + } + } else { + if t.IsVariadic() { + n-- + } + if len(in) < n { + panic("reflect: Call with too few input arguments") + } + if !t.IsVariadic() && len(in) > n { + panic("reflect: Call with too many input arguments") + } + } + for _, x := range in { + if x.Kind() == Invalid { + panic("reflect: " + method + " using zero Value argument") + } } + for i := 0; i < n; i++ { + if xt, targ := in[i].Type(), t.In(i); !xt.AssignableTo(targ) { + panic("reflect: " + method + " using " + xt.String() + " as type " + targ.String()) + } + } + if !isSlice && t.IsVariadic() { + // prepare slice for remaining values + m := len(in) - n + slice := MakeSlice(t.In(n), m, m) + elem := t.In(n).Elem() + for i := 0; i < m; i++ { + x := in[n+i] + if xt := x.Type(); !xt.AssignableTo(elem) { + panic("reflect: cannot use " + xt.String() + " as type " + elem.String() + " in " + method) + } + slice.Index(i).Set(x) + } + origIn := in + in = make([]Value, n+1) + copy(in[:n], origIn) + in[n] = slice + } + + nin := len(in) if nin != t.NumIn() { - panic("funcValue: wrong argument count") + panic("reflect.Value.Call: wrong argument count") } nout := t.NumOut() // Compute arg size & allocate. - // This computation is 6g/8g-dependent + // This computation is 5g/6g/8g-dependent // and probably wrong for gccgo, but so // is most of this function. size := uintptr(0) - if fv.isInterface { + if iv.method { // extra word for interface value size += ptrSize } @@ -215,36 +556,31 @@ func (v Value) Call(in []Value) []Value { args := make([]*int, size/ptrSize) ptr := uintptr(unsafe.Pointer(&args[0])) off := uintptr(0) - delta := 0 - if v := fv.first; v != nil { + if iv.method { // Hard-wired first argument. - if fv.isInterface { - // v is a single uninterpreted word - memmove(addr(ptr), v.getAddr(), ptrSize) - off = ptrSize - } else { - // v is a real value - tv := v.Type() - typesMustMatch(t.In(0), tv) - n := tv.Size() - memmove(addr(ptr), v.getAddr(), n) - off = n - delta = 1 - } + *(*iword)(unsafe.Pointer(ptr)) = iv.rcvr + off = ptrSize } for i, v := range in { - tv := v.Type() - typesMustMatch(t.In(i+delta), tv) - a := uintptr(tv.Align()) + iv := v.internal() + iv.mustBeExported() + targ := t.In(i).(*commonType) + a := uintptr(targ.align) off = (off + a - 1) &^ (a - 1) - n := tv.Size() - memmove(addr(ptr+off), v.internal().getAddr(), n) + n := targ.size + addr := unsafe.Pointer(ptr + off) + iv = convertForAssignment("reflect.Value.Call", addr, targ, iv) + if iv.addr == nil { + storeIword(addr, iv.word, n) + } else { + memmove(addr, iv.addr, n) + } off += n } off = (off + ptrSize - 1) &^ (ptrSize - 1) - // Call - call(*(**byte)(fv.addr), (*byte)(addr(ptr)), uint32(size)) + // Call. + call(unsafe.Pointer(iv.word), unsafe.Pointer(ptr), uint32(size)) // Copy return values out of args. // @@ -254,111 +590,148 @@ func (v Value) Call(in []Value) []Value { tv := t.Out(i) a := uintptr(tv.Align()) off = (off + a - 1) &^ (a - 1) - v := Zero(tv) - n := tv.Size() - memmove(v.internal().getAddr(), addr(ptr+off), n) - ret[i] = v - off += n + ret[i] = valueFromAddr(0, tv, unsafe.Pointer(ptr+off)) + off += tv.Size() } return ret } -var capKinds = []Kind{Array, Chan, Slice} - // Cap returns v's capacity. // It panics if v's Kind is not Array, Chan, or Slice. func (v Value) Cap() int { - switch vv := v.panicIfNots(capKinds).(type) { - case *arrayValue: - return vv.typ.Len() - case *chanValue: - ch := *(**byte)(vv.addr) - return int(chancap(ch)) - case *sliceValue: - return int(vv.slice().Cap) + iv := v.internal() + switch iv.kind { + case Array: + return iv.typ.Len() + case Chan: + return int(chancap(iv.word)) + case Slice: + return (*SliceHeader)(iv.addr).Cap } - panic("not reached") + panic(&ValueError{"reflect.Value.Cap", iv.kind}) } // Close closes the channel v. // It panics if v's Kind is not Chan. func (v Value) Close() { - vv := v.panicIfNot(Chan).(*chanValue) - - ch := *(**byte)(vv.addr) + iv := v.internal() + iv.mustBe(Chan) + iv.mustBeExported() + ch := iv.word chanclose(ch) } -var complexKinds = []Kind{Complex64, Complex128} - // Complex returns v's underlying value, as a complex128. // It panics if v's Kind is not Complex64 or Complex128 func (v Value) Complex() complex128 { - vv := v.panicIfNots(complexKinds).(*complexValue) - - switch vv.typ.Kind() { + iv := v.internal() + switch iv.kind { case Complex64: - return complex128(*(*complex64)(vv.addr)) + if iv.addr == nil { + return complex128(*(*complex64)(unsafe.Pointer(&iv.word))) + } + return complex128(*(*complex64)(iv.addr)) case Complex128: - return *(*complex128)(vv.addr) + return *(*complex128)(iv.addr) } - panic("reflect: invalid complex kind") + panic(&ValueError{"reflect.Value.Complex", iv.kind}) } -var interfaceOrPtr = []Kind{Interface, Ptr} - // Elem returns the value that the interface v contains // or that the pointer v points to. // It panics if v's Kind is not Interface or Ptr. // It returns the zero Value if v is nil. func (v Value) Elem() Value { - switch vv := v.panicIfNots(interfaceOrPtr).(type) { - case *interfaceValue: - return NewValue(vv.Interface()) - case *ptrValue: - if v.IsNil() { + iv := v.internal() + return iv.Elem() +} + +func (iv internalValue) Elem() Value { + switch iv.kind { + case Interface: + // Empty interface and non-empty interface have different layouts. + // Convert to empty interface. + var eface emptyInterface + if iv.typ.NumMethod() == 0 { + eface = *(*emptyInterface)(iv.addr) + } else { + iface := (*nonEmptyInterface)(iv.addr) + if iface.itab != nil { + eface.typ = iface.itab.typ + } + eface.word = iface.word + } + if eface.typ == nil { return Value{} } - flag := canAddr - if vv.flag&canStore != 0 { - flag |= canSet | canStore + return valueFromIword(iv.flag&flagRO, toType(eface.typ), eface.word) + + case Ptr: + // The returned value's address is v's value. + if iv.word == 0 { + return Value{} } - return newValue(vv.typ.Elem(), *(*addr)(vv.addr), flag) + return valueFromAddr(iv.flag&flagRO|flagAddr, iv.typ.Elem(), unsafe.Pointer(iv.word)) } - panic("not reached") + panic(&ValueError{"reflect.Value.Elem", iv.kind}) } // Field returns the i'th field of the struct v. -// It panics if v's Kind is not Struct. +// It panics if v's Kind is not Struct or i is out of range. func (v Value) Field(i int) Value { - vv := v.panicIfNot(Struct).(*structValue) - - t := vv.typ + iv := v.internal() + iv.mustBe(Struct) + t := iv.typ.toType() if i < 0 || i >= t.NumField() { panic("reflect: Field index out of range") } f := t.Field(i) - flag := vv.flag + + // Inherit permission bits from v. + flag := iv.flag + // Using an unexported field forces flagRO. if f.PkgPath != "" { - // unexported field - flag &^= canSet | canStore + flag |= flagRO } - return newValue(f.Type, addr(uintptr(vv.addr)+f.Offset), flag) + return valueFromValueOffset(flag, f.Type, iv, f.Offset) +} + +// valueFromValueOffset returns a sub-value of outer +// (outer is an array or a struct) with the given flag and type +// starting at the given byte offset into outer. +func valueFromValueOffset(flag uint32, typ Type, outer internalValue, offset uintptr) Value { + if outer.addr != nil { + return valueFromAddr(flag, typ, unsafe.Pointer(uintptr(outer.addr)+offset)) + } + + // outer is so tiny it is in line. + // We have to use outer.word and derive + // the new word (it cannot possibly be bigger). + // In line, so not addressable. + if flag&flagAddr != 0 { + panic("reflect: internal error: misuse of valueFromValueOffset") + } + b := *(*[ptrSize]byte)(unsafe.Pointer(&outer.word)) + for i := uintptr(0); i < typ.Size(); i++ { + b[i] = b[offset+i] + } + for i := typ.Size(); i < ptrSize; i++ { + b[i] = 0 + } + w := *(*iword)(unsafe.Pointer(&b)) + return valueFromIword(flag, typ, w) } // FieldByIndex returns the nested field corresponding to index. // It panics if v's Kind is not struct. func (v Value) FieldByIndex(index []int) Value { - v.panicIfNot(Struct) + v.internal().mustBe(Struct) for i, x := range index { if i > 0 { - if v.Kind() == Ptr { + if v.Kind() == Ptr && v.Elem().Kind() == Struct { v = v.Elem() } - if v.Kind() != Struct { - return Value{} - } } v = v.Field(x) } @@ -369,7 +742,9 @@ func (v Value) FieldByIndex(index []int) Value { // It returns the zero Value if no field was found. // It panics if v's Kind is not struct. func (v Value) FieldByName(name string) Value { - if f, ok := v.Type().FieldByName(name); ok { + iv := v.internal() + iv.mustBe(Struct) + if f, ok := iv.typ.FieldByName(name); ok { return v.FieldByIndex(f.Index) } return Value{} @@ -380,79 +755,100 @@ func (v Value) FieldByName(name string) Value { // It panics if v's Kind is not struct. // It returns the zero Value if no field was found. func (v Value) FieldByNameFunc(match func(string) bool) Value { + v.internal().mustBe(Struct) if f, ok := v.Type().FieldByNameFunc(match); ok { return v.FieldByIndex(f.Index) } return Value{} } -var floatKinds = []Kind{Float32, Float64} - // Float returns v's underlying value, as an float64. // It panics if v's Kind is not Float32 or Float64 func (v Value) Float() float64 { - vv := v.panicIfNots(floatKinds).(*floatValue) - - switch vv.typ.Kind() { + iv := v.internal() + switch iv.kind { case Float32: - return float64(*(*float32)(vv.addr)) + return float64(*(*float32)(unsafe.Pointer(&iv.word))) case Float64: - return *(*float64)(vv.addr) + // If the pointer width can fit an entire float64, + // the value is in line when stored in an interface. + if iv.addr == nil { + return *(*float64)(unsafe.Pointer(&iv.word)) + } + // Otherwise we have a pointer. + return *(*float64)(iv.addr) } - panic("reflect: invalid float kind") - + panic(&ValueError{"reflect.Value.Float", iv.kind}) } -var arrayOrSlice = []Kind{Array, Slice} - // Index returns v's i'th element. -// It panics if v's Kind is not Array or Slice. +// It panics if v's Kind is not Array or Slice or i is out of range. func (v Value) Index(i int) Value { - switch vv := v.panicIfNots(arrayOrSlice).(type) { - case *arrayValue: - typ := vv.typ.Elem() - n := v.Len() - if i < 0 || i >= n { - panic("array index out of bounds") + iv := v.internal() + switch iv.kind { + default: + panic(&ValueError{"reflect.Value.Index", iv.kind}) + case Array: + flag := iv.flag // element flag same as overall array + t := iv.typ.toType() + if i < 0 || i > t.Len() { + panic("reflect: array index out of range") } - p := addr(uintptr(vv.addr()) + uintptr(i)*typ.Size()) - return newValue(typ, p, vv.flag) - case *sliceValue: - typ := vv.typ.Elem() - n := v.Len() - if i < 0 || i >= n { + typ := t.Elem() + return valueFromValueOffset(flag, typ, iv, uintptr(i)*typ.Size()) + + case Slice: + // Element flag same as Elem of Ptr. + // Addressable, possibly read-only. + flag := iv.flag&flagRO | flagAddr + s := (*SliceHeader)(iv.addr) + if i < 0 || i >= s.Len { panic("reflect: slice index out of range") } - p := addr(uintptr(vv.addr()) + uintptr(i)*typ.Size()) - flag := canAddr - if vv.flag&canStore != 0 { - flag |= canSet | canStore - } - return newValue(typ, p, flag) + typ := iv.typ.Elem() + addr := unsafe.Pointer(s.Data + uintptr(i)*typ.Size()) + return valueFromAddr(flag, typ, addr) } + panic("not reached") } -var intKinds = []Kind{Int, Int8, Int16, Int32, Int64} - // Int returns v's underlying value, as an int64. -// It panics if v's Kind is not a sized or unsized Int kind. +// It panics if v's Kind is not Int, Int8, Int16, Int32, or Int64. func (v Value) Int() int64 { - vv := v.panicIfNots(intKinds).(*intValue) - - switch vv.typ.Kind() { + iv := v.internal() + switch iv.kind { case Int: - return int64(*(*int)(vv.addr)) + return int64(*(*int)(unsafe.Pointer(&iv.word))) case Int8: - return int64(*(*int8)(vv.addr)) + return int64(*(*int8)(unsafe.Pointer(&iv.word))) case Int16: - return int64(*(*int16)(vv.addr)) + return int64(*(*int16)(unsafe.Pointer(&iv.word))) case Int32: - return int64(*(*int32)(vv.addr)) + return int64(*(*int32)(unsafe.Pointer(&iv.word))) case Int64: - return *(*int64)(vv.addr) + if iv.addr == nil { + return *(*int64)(unsafe.Pointer(&iv.word)) + } + return *(*int64)(iv.addr) } - panic("reflect: invalid int kind") + panic(&ValueError{"reflect.Value.Int", iv.kind}) +} + +// CanInterface returns true if Interface can be used without panicking. +func (v Value) CanInterface() bool { + iv := v.internal() + if iv.kind == Invalid { + panic(&ValueError{"reflect.Value.CanInterface", iv.kind}) + } + // TODO(rsc): Check flagRO too. Decide what to do about asking for + // interface for a value obtained via an unexported field. + // If the field were of a known type, say chan int or *sync.Mutex, + // the caller could interfere with the data after getting the + // interface. But fmt.Print depends on being able to look. + // Now that reflect is more efficient the special cases in fmt + // might be less important. + return v.InternalMethod == 0 } // Interface returns v's value as an interface{}. @@ -463,34 +859,62 @@ func (v Value) Interface() interface{} { return v.internal().Interface() } +func (iv internalValue) Interface() interface{} { + if iv.method { + panic("reflect.Value.Interface: cannot create interface value for method with bound receiver") + } + /* + if v.flag()&noExport != 0 { + panic("reflect.Value.Interface: cannot return value obtained from unexported struct field") + } + */ + + if iv.kind == Interface { + // Special case: return the element inside the interface. + // Won't recurse further because an interface cannot contain an interface. + if iv.IsNil() { + return nil + } + return iv.Elem().Interface() + } + + // Non-interface value. + var eface emptyInterface + eface.typ = iv.typ.runtimeType() + eface.word = iv.word + return *(*interface{})(unsafe.Pointer(&eface)) +} + // InterfaceData returns the interface v's value as a uintptr pair. // It panics if v's Kind is not Interface. func (v Value) InterfaceData() [2]uintptr { - vv := v.panicIfNot(Interface).(*interfaceValue) - - return *(*[2]uintptr)(vv.addr) + iv := v.internal() + iv.mustBe(Interface) + // We treat this as a read operation, so we allow + // it even for unexported data, because the caller + // has to import "unsafe" to turn it into something + // that can be abused. + return *(*[2]uintptr)(iv.addr) } -var nilKinds = []Kind{Chan, Func, Interface, Map, Ptr, Slice} - // IsNil returns true if v is a nil value. // It panics if v's Kind is not Chan, Func, Interface, Map, Ptr, or Slice. func (v Value) IsNil() bool { - switch vv := v.panicIfNots(nilKinds).(type) { - case *chanValue: - return *(*uintptr)(vv.addr) == 0 - case *funcValue: - return *(*uintptr)(vv.addr) == 0 - case *interfaceValue: - return vv.Interface() == nil - case *mapValue: - return *(*uintptr)(vv.addr) == 0 - case *ptrValue: - return *(*uintptr)(vv.addr) == 0 - case *sliceValue: - return vv.slice().Data == 0 + return v.internal().IsNil() +} + +func (iv internalValue) IsNil() bool { + switch iv.kind { + case Chan, Func, Map, Ptr: + if iv.method { + panic("reflect: IsNil of method Value") + } + return iv.word == 0 + case Interface, Slice: + // Both interface and slice are nil if first word is 0. + return *(*uintptr)(iv.addr) == 0 } - panic("not reached") + panic(&ValueError{"reflect.Value.IsNil", iv.kind}) } // IsValid returns true if v represents a value. @@ -505,169 +929,179 @@ func (v Value) IsValid() bool { // Kind returns v's Kind. // If v is the zero Value (IsValid returns false), Kind returns Invalid. func (v Value) Kind() Kind { - if v.Internal == nil { - return Invalid - } - return v.internal().Kind() + return v.internal().kind } -var lenKinds = []Kind{Array, Chan, Map, Slice} - // Len returns v's length. // It panics if v's Kind is not Array, Chan, Map, or Slice. func (v Value) Len() int { - switch vv := v.panicIfNots(lenKinds).(type) { - case *arrayValue: - return vv.typ.Len() - case *chanValue: - ch := *(**byte)(vv.addr) - return int(chanlen(ch)) - case *mapValue: - m := *(**byte)(vv.addr) - if m == nil { - return 0 - } - return int(maplen(m)) - case *sliceValue: - return int(vv.slice().Len) + iv := v.internal() + switch iv.kind { + case Array: + return iv.typ.Len() + case Chan: + return int(chanlen(iv.word)) + case Map: + return int(maplen(iv.word)) + case Slice: + return (*SliceHeader)(iv.addr).Len } - panic("not reached") + panic(&ValueError{"reflect.Value.Len", iv.kind}) } // MapIndex returns the value associated with key in the map v. // It panics if v's Kind is not Map. -// It returns the zero Value if key is not found in the map. +// It returns the zero Value if key is not found in the map or if v represents a nil map. +// As in Go, the key's value must be assignable to the map's key type. func (v Value) MapIndex(key Value) Value { - vv := v.panicIfNot(Map).(*mapValue) - t := vv.Type() - typesMustMatch(t.Key(), key.Type()) - m := *(**byte)(vv.addr) - if m == nil { + iv := v.internal() + iv.mustBe(Map) + typ := iv.typ.toType() + + ikey := key.internal() + ikey.mustBeExported() + ikey = convertForAssignment("reflect.Value.MapIndex", nil, typ.Key(), ikey) + if iv.word == 0 { return Value{} } - newval := Zero(t.Elem()) - if !mapaccess(m, (*byte)(key.internal().getAddr()), (*byte)(newval.internal().getAddr())) { + + flag := iv.flag & flagRO + elemType := typ.Elem() + elemWord, ok := mapaccess(iv.word, ikey.word) + if !ok { return Value{} } - return newval + return valueFromIword(flag, elemType, elemWord) } // MapKeys returns a slice containing all the keys present in the map, // in unspecified order. // It panics if v's Kind is not Map. +// It returns an empty slice if v represents a nil map. func (v Value) MapKeys() []Value { - vv := v.panicIfNot(Map).(*mapValue) - tk := vv.Type().Key() - m := *(**byte)(vv.addr) + iv := v.internal() + iv.mustBe(Map) + keyType := iv.typ.Key() + + flag := iv.flag & flagRO + m := iv.word mlen := int32(0) - if m != nil { + if m != 0 { mlen = maplen(m) } it := mapiterinit(m) a := make([]Value, mlen) var i int for i = 0; i < len(a); i++ { - k := Zero(tk) - if !mapiterkey(it, (*byte)(k.internal().getAddr())) { + keyWord, ok := mapiterkey(it) + if !ok { break } - a[i] = k + a[i] = valueFromIword(flag, keyType, keyWord) mapiternext(it) } - return a[0:i] + return a[:i] } // Method returns a function value corresponding to v's i'th method. // The arguments to a Call on the returned function should not include // a receiver; the returned function will always use v as the receiver. +// Method panics if i is out of range. func (v Value) Method(i int) Value { - return v.internal().Method(i) + iv := v.internal() + if iv.kind == Invalid { + panic(&ValueError{"reflect.Value.Method", Invalid}) + } + if i < 0 || i >= iv.typ.NumMethod() { + panic("reflect: Method index out of range") + } + return Value{v.Internal, i + 1} } // NumField returns the number of fields in the struct v. // It panics if v's Kind is not Struct. func (v Value) NumField() int { - return v.panicIfNot(Struct).(*structValue).typ.NumField() + iv := v.internal() + iv.mustBe(Struct) + return iv.typ.NumField() } // OverflowComplex returns true if the complex128 x cannot be represented by v's type. // It panics if v's Kind is not Complex64 or Complex128. func (v Value) OverflowComplex(x complex128) bool { - vv := v.panicIfNots(complexKinds).(*complexValue) - - if vv.typ.Size() == 16 { + iv := v.internal() + switch iv.kind { + case Complex64: + return overflowFloat32(real(x)) || overflowFloat32(imag(x)) + case Complex128: return false } - r := real(x) - i := imag(x) - if r < 0 { - r = -r - } - if i < 0 { - i = -i - } - return math.MaxFloat32 <= r && r <= math.MaxFloat64 || - math.MaxFloat32 <= i && i <= math.MaxFloat64 + panic(&ValueError{"reflect.Value.OverflowComplex", iv.kind}) } // OverflowFloat returns true if the float64 x cannot be represented by v's type. // It panics if v's Kind is not Float32 or Float64. func (v Value) OverflowFloat(x float64) bool { - vv := v.panicIfNots(floatKinds).(*floatValue) - - if vv.typ.Size() == 8 { + iv := v.internal() + switch iv.kind { + case Float32: + return overflowFloat32(x) + case Float64: return false } + panic(&ValueError{"reflect.Value.OverflowFloat", iv.kind}) +} + +func overflowFloat32(x float64) bool { if x < 0 { x = -x } - return math.MaxFloat32 < x && x <= math.MaxFloat64 + return math.MaxFloat32 <= x && x <= math.MaxFloat64 } // OverflowInt returns true if the int64 x cannot be represented by v's type. -// It panics if v's Kind is not a sized or unsized Int kind. +// It panics if v's Kind is not Int, Int8, int16, Int32, or Int64. func (v Value) OverflowInt(x int64) bool { - vv := v.panicIfNots(intKinds).(*intValue) - - bitSize := uint(vv.typ.Bits()) - trunc := (x << (64 - bitSize)) >> (64 - bitSize) - return x != trunc + iv := v.internal() + switch iv.kind { + case Int, Int8, Int16, Int32, Int64: + bitSize := iv.typ.size * 8 + trunc := (x << (64 - bitSize)) >> (64 - bitSize) + return x != trunc + } + panic(&ValueError{"reflect.Value.OverflowInt", iv.kind}) } // OverflowUint returns true if the uint64 x cannot be represented by v's type. -// It panics if v's Kind is not a sized or unsized Uint kind. +// It panics if v's Kind is not Uint, Uintptr, Uint8, Uint16, Uint32, or Uint64. func (v Value) OverflowUint(x uint64) bool { - vv := v.panicIfNots(uintKinds).(*uintValue) - - bitSize := uint(vv.typ.Bits()) - trunc := (x << (64 - bitSize)) >> (64 - bitSize) - return x != trunc + iv := v.internal() + switch iv.kind { + case Uint, Uintptr, Uint8, Uint16, Uint32, Uint64: + bitSize := iv.typ.size * 8 + trunc := (x << (64 - bitSize)) >> (64 - bitSize) + return x != trunc + } + panic(&ValueError{"reflect.Value.OverflowUint", iv.kind}) } -var pointerKinds = []Kind{Chan, Func, Map, Ptr, Slice, UnsafePointer} - // Pointer returns v's value as a uintptr. // It returns uintptr instead of unsafe.Pointer so that // code using reflect cannot obtain unsafe.Pointers // without importing the unsafe package explicitly. // It panics if v's Kind is not Chan, Func, Map, Ptr, Slice, or UnsafePointer. func (v Value) Pointer() uintptr { - switch vv := v.panicIfNots(pointerKinds).(type) { - case *chanValue: - return *(*uintptr)(vv.addr) - case *funcValue: - return *(*uintptr)(vv.addr) - case *mapValue: - return *(*uintptr)(vv.addr) - case *ptrValue: - return *(*uintptr)(vv.addr) - case *sliceValue: - typ := vv.typ - return uintptr(vv.addr()) + uintptr(v.Cap())*typ.Elem().Size() - case *unsafePointerValue: - return uintptr(*(*unsafe.Pointer)(vv.addr)) + iv := v.internal() + switch iv.kind { + case Chan, Func, Map, Ptr, UnsafePointer: + if iv.kind == Func && v.InternalMethod != 0 { + panic("reflect.Value.Pointer of method Value") + } + return uintptr(iv.word) + case Slice: + return (*SliceHeader)(iv.addr).Data } - panic("not reached") + panic(&ValueError{"reflect.Value.Pointer", iv.kind}) } // Recv receives and returns a value from the channel v. @@ -676,233 +1110,142 @@ func (v Value) Pointer() uintptr { // The boolean value ok is true if the value x corresponds to a send // on the channel, false if it is a zero value received because the channel is closed. func (v Value) Recv() (x Value, ok bool) { - return v.panicIfNot(Chan).(*chanValue).recv(nil) + iv := v.internal() + iv.mustBe(Chan) + iv.mustBeExported() + return iv.recv(false) } -// internal recv; non-blocking if selected != nil -func (v *chanValue) recv(selected *bool) (Value, bool) { - t := v.Type() +// internal recv, possibly non-blocking (nb) +func (iv internalValue) recv(nb bool) (val Value, ok bool) { + t := iv.typ.toType() if t.ChanDir()&RecvDir == 0 { panic("recv on send-only channel") } - ch := *(**byte)(v.addr) - x := Zero(t.Elem()) - var ok bool - chanrecv(ch, (*byte)(x.internal().getAddr()), selected, &ok) - return x, ok + ch := iv.word + if ch == 0 { + panic("recv on nil channel") + } + valWord, selected, ok := chanrecv(ch, nb) + if selected { + val = valueFromIword(0, t.Elem(), valWord) + } + return } // Send sends x on the channel v. // It panics if v's kind is not Chan or if x's type is not the same type as v's element type. +// As in Go, x's value must be assignable to the channel's element type. func (v Value) Send(x Value) { - v.panicIfNot(Chan).(*chanValue).send(x, nil) + iv := v.internal() + iv.mustBe(Chan) + iv.mustBeExported() + iv.send(x, false) } -// internal send; non-blocking if selected != nil -func (v *chanValue) send(x Value, selected *bool) { - t := v.Type() +// internal send, possibly non-blocking +func (iv internalValue) send(x Value, nb bool) (selected bool) { + t := iv.typ.toType() if t.ChanDir()&SendDir == 0 { panic("send on recv-only channel") } - typesMustMatch(t.Elem(), x.Type()) - ch := *(**byte)(v.addr) - chansend(ch, (*byte)(x.internal().getAddr()), selected) + ix := x.internal() + ix.mustBeExported() // do not let unexported x leak + ix = convertForAssignment("reflect.Value.Send", nil, t.Elem(), ix) + ch := iv.word + if ch == 0 { + panic("send on nil channel") + } + return chansend(ch, ix.word, nb) } -// Set assigns x to the value v; x must have the same type as v. -// It panics if CanSet() returns false or if x is the zero Value. +// Set assigns x to the value v. +// It panics if CanSet returns false. +// As in Go, x's value must be assignable to v's type. func (v Value) Set(x Value) { - x.internal() - switch vv := v.internal().(type) { - case *arrayValue: - xx := x.panicIfNot(Array).(*arrayValue) - if !vv.CanSet() { - panic(cannotSet) - } - typesMustMatch(vv.typ, xx.typ) - Copy(v, x) + iv := v.internal() + ix := x.internal() - case *boolValue: - v.SetBool(x.Bool()) - - case *chanValue: - x := x.panicIfNot(Chan).(*chanValue) - if !vv.CanSet() { - panic(cannotSet) - } - typesMustMatch(vv.typ, x.typ) - *(*uintptr)(vv.addr) = *(*uintptr)(x.addr) - - case *floatValue: - v.SetFloat(x.Float()) - - case *funcValue: - x := x.panicIfNot(Func).(*funcValue) - if !vv.CanSet() { - panic(cannotSet) - } - typesMustMatch(vv.typ, x.typ) - *(*uintptr)(vv.addr) = *(*uintptr)(x.addr) + iv.mustBeAssignable() + ix.mustBeExported() // do not let unexported x leak - case *intValue: - v.SetInt(x.Int()) + ix = convertForAssignment("reflect.Set", iv.addr, iv.typ, ix) - case *interfaceValue: - i := x.Interface() - if !vv.CanSet() { - panic(cannotSet) - } - // Two different representations; see comment in Get. - // Empty interface is easy. - t := (*interfaceType)(unsafe.Pointer(vv.typ.(*commonType))) - if t.NumMethod() == 0 { - *(*interface{})(vv.addr) = i - return - } - - // Non-empty interface requires a runtime check. - setiface(t, &i, vv.addr) - - case *mapValue: - x := x.panicIfNot(Map).(*mapValue) - if !vv.CanSet() { - panic(cannotSet) - } - if x == nil { - *(**uintptr)(vv.addr) = nil - return - } - typesMustMatch(vv.typ, x.typ) - *(*uintptr)(vv.addr) = *(*uintptr)(x.addr) - - case *ptrValue: - x := x.panicIfNot(Ptr).(*ptrValue) - if x == nil { - *(**uintptr)(vv.addr) = nil - return - } - if !vv.CanSet() { - panic(cannotSet) - } - if x.flag&canStore == 0 { - panic("cannot copy pointer obtained from unexported struct field") - } - typesMustMatch(vv.typ, x.typ) - // TODO: This will have to move into the runtime - // once the new gc goes in - *(*uintptr)(vv.addr) = *(*uintptr)(x.addr) - - case *sliceValue: - x := x.panicIfNot(Slice).(*sliceValue) - if !vv.CanSet() { - panic(cannotSet) - } - typesMustMatch(vv.typ, x.typ) - *vv.slice() = *x.slice() - - case *stringValue: - // Do the kind check explicitly, because x.String() does not. - x.panicIfNot(String) - v.SetString(x.String()) - - case *structValue: - x := x.panicIfNot(Struct).(*structValue) - // TODO: This will have to move into the runtime - // once the gc goes in. - if !vv.CanSet() { - panic(cannotSet) - } - typesMustMatch(vv.typ, x.typ) - memmove(vv.addr, x.addr, vv.typ.Size()) - - case *uintValue: - v.SetUint(x.Uint()) - - case *unsafePointerValue: - // Do the kind check explicitly, because x.UnsafePointer - // applies to more than just the UnsafePointer Kind. - x.panicIfNot(UnsafePointer) - v.SetPointer(unsafe.Pointer(x.Pointer())) + n := ix.typ.size + if n <= ptrSize { + storeIword(iv.addr, ix.word, n) + } else { + memmove(iv.addr, ix.addr, n) } } // SetBool sets v's underlying value. // It panics if v's Kind is not Bool or if CanSet() is false. func (v Value) SetBool(x bool) { - vv := v.panicIfNot(Bool).(*boolValue) - - if !vv.CanSet() { - panic(cannotSet) - } - *(*bool)(vv.addr) = x + iv := v.internal() + iv.mustBeAssignable() + iv.mustBe(Bool) + *(*bool)(iv.addr) = x } // SetComplex sets v's underlying value to x. // It panics if v's Kind is not Complex64 or Complex128, or if CanSet() is false. func (v Value) SetComplex(x complex128) { - vv := v.panicIfNots(complexKinds).(*complexValue) - - if !vv.CanSet() { - panic(cannotSet) - } - switch vv.typ.Kind() { + iv := v.internal() + iv.mustBeAssignable() + switch iv.kind { default: - panic("reflect: invalid complex kind") + panic(&ValueError{"reflect.Value.SetComplex", iv.kind}) case Complex64: - *(*complex64)(vv.addr) = complex64(x) + *(*complex64)(iv.addr) = complex64(x) case Complex128: - *(*complex128)(vv.addr) = x + *(*complex128)(iv.addr) = x } } // SetFloat sets v's underlying value to x. // It panics if v's Kind is not Float32 or Float64, or if CanSet() is false. func (v Value) SetFloat(x float64) { - vv := v.panicIfNots(floatKinds).(*floatValue) - - if !vv.CanSet() { - panic(cannotSet) - } - switch vv.typ.Kind() { + iv := v.internal() + iv.mustBeAssignable() + switch iv.kind { default: - panic("reflect: invalid float kind") + panic(&ValueError{"reflect.Value.SetFloat", iv.kind}) case Float32: - *(*float32)(vv.addr) = float32(x) + *(*float32)(iv.addr) = float32(x) case Float64: - *(*float64)(vv.addr) = x + *(*float64)(iv.addr) = x } } // SetInt sets v's underlying value to x. -// It panics if v's Kind is not a sized or unsized Int kind, or if CanSet() is false. +// It panics if v's Kind is not Int, Int8, Int16, Int32, or Int64, or if CanSet() is false. func (v Value) SetInt(x int64) { - vv := v.panicIfNots(intKinds).(*intValue) - - if !vv.CanSet() { - panic(cannotSet) - } - switch vv.typ.Kind() { + iv := v.internal() + iv.mustBeAssignable() + switch iv.kind { default: - panic("reflect: invalid int kind") + panic(&ValueError{"reflect.Value.SetInt", iv.kind}) case Int: - *(*int)(vv.addr) = int(x) + *(*int)(iv.addr) = int(x) case Int8: - *(*int8)(vv.addr) = int8(x) + *(*int8)(iv.addr) = int8(x) case Int16: - *(*int16)(vv.addr) = int16(x) + *(*int16)(iv.addr) = int16(x) case Int32: - *(*int32)(vv.addr) = int32(x) + *(*int32)(iv.addr) = int32(x) case Int64: - *(*int64)(vv.addr) = x + *(*int64)(iv.addr) = x } } // SetLen sets v's length to n. // It panics if v's Kind is not Slice. func (v Value) SetLen(n int) { - vv := v.panicIfNot(Slice).(*sliceValue) - - s := vv.slice() + iv := v.internal() + iv.mustBeAssignable() + iv.mustBe(Slice) + s := (*SliceHeader)(iv.addr) if n < 0 || n > int(s.Cap) { panic("reflect: slice length out of range in SetLen") } @@ -912,91 +1255,97 @@ func (v Value) SetLen(n int) { // SetMapIndex sets the value associated with key in the map v to val. // It panics if v's Kind is not Map. // If val is the zero Value, SetMapIndex deletes the key from the map. +// As in Go, key's value must be assignable to the map's key type, +// and val's value must be assignable to the map's value type. func (v Value) SetMapIndex(key, val Value) { - vv := v.panicIfNot(Map).(*mapValue) - t := vv.Type() - typesMustMatch(t.Key(), key.Type()) - var vaddr *byte - if val.IsValid() { - typesMustMatch(t.Elem(), val.Type()) - vaddr = (*byte)(val.internal().getAddr()) + iv := v.internal() + ikey := key.internal() + ival := val.internal() + + iv.mustBe(Map) + iv.mustBeExported() + + ikey.mustBeExported() + ikey = convertForAssignment("reflect.Value.SetMapIndex", nil, iv.typ.Key(), ikey) + + if ival.kind != Invalid { + ival.mustBeExported() + ival = convertForAssignment("reflect.Value.SetMapIndex", nil, iv.typ.Elem(), ival) } - m := *(**byte)(vv.addr) - mapassign(m, (*byte)(key.internal().getAddr()), vaddr) + + mapassign(iv.word, ikey.word, ival.word, ival.kind != Invalid) } // SetUint sets v's underlying value to x. -// It panics if v's Kind is not a sized or unsized Uint kind, or if CanSet() is false. +// It panics if v's Kind is not Uint, Uintptr, Uint8, Uint16, Uint32, or Uint64, or if CanSet() is false. func (v Value) SetUint(x uint64) { - vv := v.panicIfNots(uintKinds).(*uintValue) - - if !vv.CanSet() { - panic(cannotSet) - } - switch vv.typ.Kind() { + iv := v.internal() + iv.mustBeAssignable() + switch iv.kind { default: - panic("reflect: invalid uint kind") + panic(&ValueError{"reflect.Value.SetUint", iv.kind}) case Uint: - *(*uint)(vv.addr) = uint(x) + *(*uint)(iv.addr) = uint(x) case Uint8: - *(*uint8)(vv.addr) = uint8(x) + *(*uint8)(iv.addr) = uint8(x) case Uint16: - *(*uint16)(vv.addr) = uint16(x) + *(*uint16)(iv.addr) = uint16(x) case Uint32: - *(*uint32)(vv.addr) = uint32(x) + *(*uint32)(iv.addr) = uint32(x) case Uint64: - *(*uint64)(vv.addr) = x + *(*uint64)(iv.addr) = x case Uintptr: - *(*uintptr)(vv.addr) = uintptr(x) + *(*uintptr)(iv.addr) = uintptr(x) } } // SetPointer sets the unsafe.Pointer value v to x. // It panics if v's Kind is not UnsafePointer. func (v Value) SetPointer(x unsafe.Pointer) { - vv := v.panicIfNot(UnsafePointer).(*unsafePointerValue) - - if !vv.CanSet() { - panic(cannotSet) - } - *(*unsafe.Pointer)(vv.addr) = x + iv := v.internal() + iv.mustBeAssignable() + iv.mustBe(UnsafePointer) + *(*unsafe.Pointer)(iv.addr) = x } // SetString sets v's underlying value to x. // It panics if v's Kind is not String or if CanSet() is false. func (v Value) SetString(x string) { - vv := v.panicIfNot(String).(*stringValue) - - if !vv.CanSet() { - panic(cannotSet) - } - *(*string)(vv.addr) = x + iv := v.internal() + iv.mustBeAssignable() + iv.mustBe(String) + *(*string)(iv.addr) = x } -// BUG(rsc): Value.Slice should allow slicing arrays. - // Slice returns a slice of v. -// It panics if v's Kind is not Slice. +// It panics if v's Kind is not Array or Slice. func (v Value) Slice(beg, end int) Value { - vv := v.panicIfNot(Slice).(*sliceValue) - + iv := v.internal() + if iv.kind != Array && iv.kind != Slice { + panic(&ValueError{"reflect.Value.Slice", iv.kind}) + } cap := v.Cap() if beg < 0 || end < beg || end > cap { - panic("slice index out of bounds") + panic("reflect.Value.Slice: slice index out of bounds") + } + var typ Type + var base uintptr + switch iv.kind { + case Array: + if iv.flag&flagAddr == 0 { + panic("reflect.Value.Slice: slice of unaddressable array") + } + typ = toType((*arrayType)(unsafe.Pointer(iv.typ)).slice) + base = uintptr(iv.addr) + case Slice: + typ = iv.typ.toType() + base = (*SliceHeader)(iv.addr).Data } - typ := vv.typ s := new(SliceHeader) - s.Data = uintptr(vv.addr()) + uintptr(beg)*typ.Elem().Size() + s.Data = base + uintptr(beg)*typ.Elem().Size() s.Len = end - beg s.Cap = cap - beg - - // Like the result of Addr, we treat Slice as an - // unaddressable temporary, so don't set canAddr. - flag := canSet - if vv.flag&canStore != 0 { - flag |= canStore - } - return newValue(typ, addr(s), flag) + return valueFromAddr(iv.flag&flagRO, typ, unsafe.Pointer(s)) } // String returns the string v's underlying value, as a string. @@ -1004,15 +1353,14 @@ func (v Value) Slice(beg, end int) Value { // Unlike the other getters, it does not panic if v's Kind is not String. // Instead, it returns a string of the form "<T value>" where T is v's type. func (v Value) String() string { - vi := v.Internal - if vi == nil { + iv := v.internal() + switch iv.kind { + case Invalid: return "<invalid Value>" + case String: + return *(*string)(iv.addr) } - if vi.Kind() == String { - vv := vi.(*stringValue) - return *(*string)(vv.addr) - } - return "<" + vi.Type().String() + " Value>" + return "<" + iv.typ.String() + " Value>" } // TryRecv attempts to receive a value from the channel v but will not block. @@ -1021,241 +1369,98 @@ func (v Value) String() string { // The boolean ok is true if the value x corresponds to a send // on the channel, false if it is a zero value received because the channel is closed. func (v Value) TryRecv() (x Value, ok bool) { - vv := v.panicIfNot(Chan).(*chanValue) - - var selected bool - x, ok = vv.recv(&selected) - if !selected { - return Value{}, false - } - return x, ok + iv := v.internal() + iv.mustBe(Chan) + iv.mustBeExported() + return iv.recv(true) } // TrySend attempts to send x on the channel v but will not block. // It panics if v's Kind is not Chan. // It returns true if the value was sent, false otherwise. +// As in Go, x's value must be assignable to the channel's element type. func (v Value) TrySend(x Value) bool { - vv := v.panicIfNot(Chan).(*chanValue) - - var selected bool - vv.send(x, &selected) - return selected + iv := v.internal() + iv.mustBe(Chan) + iv.mustBeExported() + return iv.send(x, true) } // Type returns v's type. func (v Value) Type() Type { - return v.internal().Type() + t := v.internal().typ + if t == nil { + panic(&ValueError{"reflect.Value.Type", Invalid}) + } + return t.toType() } -var uintKinds = []Kind{Uint, Uint8, Uint16, Uint32, Uint64, Uintptr} - // Uint returns v's underlying value, as a uint64. -// It panics if v's Kind is not a sized or unsized Uint kind. +// It panics if v's Kind is not Uint, Uintptr, Uint8, Uint16, Uint32, or Uint64. func (v Value) Uint() uint64 { - vv := v.panicIfNots(uintKinds).(*uintValue) - - switch vv.typ.Kind() { + iv := v.internal() + switch iv.kind { case Uint: - return uint64(*(*uint)(vv.addr)) + return uint64(*(*uint)(unsafe.Pointer(&iv.word))) case Uint8: - return uint64(*(*uint8)(vv.addr)) + return uint64(*(*uint8)(unsafe.Pointer(&iv.word))) case Uint16: - return uint64(*(*uint16)(vv.addr)) + return uint64(*(*uint16)(unsafe.Pointer(&iv.word))) case Uint32: - return uint64(*(*uint32)(vv.addr)) - case Uint64: - return *(*uint64)(vv.addr) + return uint64(*(*uint32)(unsafe.Pointer(&iv.word))) case Uintptr: - return uint64(*(*uintptr)(vv.addr)) + return uint64(*(*uintptr)(unsafe.Pointer(&iv.word))) + case Uint64: + if iv.addr == nil { + return *(*uint64)(unsafe.Pointer(&iv.word)) + } + return *(*uint64)(iv.addr) } - panic("reflect: invalid uint kind") + panic(&ValueError{"reflect.Value.Uint", iv.kind}) } // UnsafeAddr returns a pointer to v's data. // It is for advanced clients that also import the "unsafe" package. +// It panics if v is not addressable. func (v Value) UnsafeAddr() uintptr { - return v.internal().UnsafeAddr() -} - -// valueInterface is the common interface to reflection values. -// The implementations of Value (e.g., arrayValue, structValue) -// have additional type-specific methods. -type valueInterface interface { - // Type returns the value's type. - Type() Type - - // Interface returns the value as an interface{}. - Interface() interface{} - - // CanSet returns true if the value can be changed. - // Values obtained by the use of non-exported struct fields - // can be used in Get but not Set. - // If CanSet returns false, calling the type-specific Set will panic. - CanSet() bool - - // CanAddr returns true if the value's address can be obtained with Addr. - // Such values are called addressable. A value is addressable if it is - // an element of a slice, an element of an addressable array, - // a field of an addressable struct, the result of dereferencing a pointer, - // or the result of a call to NewValue, MakeChan, MakeMap, or Zero. - // If CanAddr returns false, calling Addr will panic. - CanAddr() bool - - // Addr returns the address of the value. - // If the value is not addressable, Addr panics. - // Addr is typically used to obtain a pointer to a struct field or slice element - // in order to call a method that requires a pointer receiver. - Addr() Value - - // UnsafeAddr returns a pointer to the underlying data. - // It is for advanced clients that also import the "unsafe" package. - UnsafeAddr() uintptr - - // Method returns a funcValue corresponding to the value's i'th method. - // The arguments to a Call on the returned funcValue - // should not include a receiver; the funcValue will use - // the value as the receiver. - Method(i int) Value - - Kind() Kind - - getAddr() addr -} - -// flags for value -const ( - canSet uint32 = 1 << iota // can set value (write to *v.addr) - canAddr // can take address of value - canStore // can store through value (write to **v.addr) -) - -// value is the common implementation of most values. -// It is embedded in other, public struct types, but always -// with a unique tag like "uint" or "float" so that the client cannot -// convert from, say, *uintValue to *floatValue. -type value struct { - typ Type - addr addr - flag uint32 -} - -func (v *value) Type() Type { return v.typ } - -func (v *value) Kind() Kind { return v.typ.Kind() } - -func (v *value) Addr() Value { - if !v.CanAddr() { - panic("reflect: cannot take address of value") + iv := v.internal() + if iv.kind == Invalid { + panic(&ValueError{"reflect.Value.UnsafeAddr", iv.kind}) } - a := v.addr - flag := canSet - if v.CanSet() { - flag |= canStore + if iv.flag&flagAddr == 0 { + panic("reflect.Value.UnsafeAddr of unaddressable value") } - // We could safely set canAddr here too - - // the caller would get the address of a - - // but it doesn't match the Go model. - // The language doesn't let you say &&v. - return newValue(PtrTo(v.typ), addr(&a), flag) -} - -func (v *value) UnsafeAddr() uintptr { return uintptr(v.addr) } - -func (v *value) getAddr() addr { return v.addr } - -func (v *value) Interface() interface{} { - typ := v.typ - if typ.Kind() == Interface { - // There are two different representations of interface values, - // one if the interface type has methods and one if it doesn't. - // These two representations require different expressions - // to extract correctly. - if typ.NumMethod() == 0 { - // Extract as interface value without methods. - return *(*interface{})(v.addr) - } - // Extract from v.addr as interface value with methods. - return *(*interface { - m() - })(v.addr) - } - return unsafe.Unreflect(v.typ, unsafe.Pointer(v.addr)) -} - -func (v *value) CanSet() bool { return v.flag&canSet != 0 } - -func (v *value) CanAddr() bool { return v.flag&canAddr != 0 } - - -/* - * basic types - */ - -// boolValue represents a bool value. -type boolValue struct { - value "bool" -} - -// floatValue represents a float value. -type floatValue struct { - value "float" -} - -// complexValue represents a complex value. -type complexValue struct { - value "complex" -} - -// intValue represents an int value. -type intValue struct { - value "int" + return uintptr(iv.addr) } // StringHeader is the runtime representation of a string. +// It cannot be used safely or portably. type StringHeader struct { Data uintptr Len int } -// stringValue represents a string value. -type stringValue struct { - value "string" -} - -// uintValue represents a uint value. -type uintValue struct { - value "uint" -} - -// unsafePointerValue represents an unsafe.Pointer value. -type unsafePointerValue struct { - value "unsafe.Pointer" +// SliceHeader is the runtime representation of a slice. +// It cannot be used safely or portably. +type SliceHeader struct { + Data uintptr + Len int + Cap int } -func typesMustMatch(t1, t2 Type) { +func typesMustMatch(what string, t1, t2 Type) { if t1 != t2 { - panic("type mismatch: " + t1.String() + " != " + t2.String()) + panic("reflect: " + what + ": " + t1.String() + " != " + t2.String()) } } -/* - * array - */ - -// ArrayOrSliceValue is the common interface -// implemented by both arrayValue and sliceValue. -type arrayOrSliceValue interface { - valueInterface - addr() addr -} - // grow grows the slice s so that it can hold extra more values, allocating // more capacity if needed. It also returns the old and new slice lengths. func grow(s Value, extra int) (Value, int, int) { i0 := s.Len() i1 := i0 + extra if i1 < i0 { - panic("append: slice overflow") + panic("reflect.Append: slice overflow") } m := s.Cap() if i1 <= m { @@ -1278,10 +1483,10 @@ func grow(s Value, extra int) (Value, int, int) { } // Append appends the values x to a slice s and returns the resulting slice. -// Each x must have the same type as s' element type. +// As in Go, each x's value must be assignable to the slice's element type. func Append(s Value, x ...Value) Value { + s.internal().mustBe(Slice) s, i0, i1 := grow(s, len(x)) - s.panicIfNot(Slice) for i, j := i0, 0; i < i1; i, j = i+1, j+1 { s.Index(i).Set(x[j]) } @@ -1291,6 +1496,9 @@ func Append(s Value, x ...Value) Value { // AppendSlice appends a slice t to a slice s and returns the resulting slice. // The slices s and t must have the same element type. func AppendSlice(s, t Value) Value { + s.internal().mustBe(Slice) + t.internal().mustBe(Slice) + typesMustMatch("reflect.AppendSlice", s.Type().Elem(), t.Type().Elem()) s, i0, i1 := grow(s, t.Len()) Copy(s.Slice(i0, i1), t) return s @@ -1299,52 +1507,61 @@ func AppendSlice(s, t Value) Value { // Copy copies the contents of src into dst until either // dst has been filled or src has been exhausted. // It returns the number of elements copied. -// Dst and src each must be a slice or array, and they -// must have the same element type. +// Dst and src each must have kind Slice or Array, and +// dst and src must have the same element type. func Copy(dst, src Value) int { - // TODO: This will have to move into the runtime - // once the real gc goes in. - de := dst.Type().Elem() - se := src.Type().Elem() - typesMustMatch(de, se) - n := dst.Len() - if xn := src.Len(); n > xn { - n = xn - } - memmove(dst.panicIfNots(arrayOrSlice).(arrayOrSliceValue).addr(), - src.panicIfNots(arrayOrSlice).(arrayOrSliceValue).addr(), - uintptr(n)*de.Size()) - return n -} + idst := dst.internal() + isrc := src.internal() -// An arrayValue represents an array. -type arrayValue struct { - value "array" -} + if idst.kind != Array && idst.kind != Slice { + panic(&ValueError{"reflect.Copy", idst.kind}) + } + if idst.kind == Array { + idst.mustBeAssignable() + } + idst.mustBeExported() + if isrc.kind != Array && isrc.kind != Slice { + panic(&ValueError{"reflect.Copy", isrc.kind}) + } + isrc.mustBeExported() -// addr returns the base address of the data in the array. -func (v *arrayValue) addr() addr { return v.value.addr } + de := idst.typ.Elem() + se := isrc.typ.Elem() + typesMustMatch("reflect.Copy", de, se) -/* - * slice - */ + n := dst.Len() + if sn := src.Len(); n > sn { + n = sn + } -// runtime representation of slice -type SliceHeader struct { - Data uintptr - Len int - Cap int -} + // If sk is an in-line array, cannot take its address. + // Instead, copy element by element. + if isrc.addr == nil { + for i := 0; i < n; i++ { + dst.Index(i).Set(src.Index(i)) + } + return n + } -// A sliceValue represents a slice. -type sliceValue struct { - value "slice" + // Copy via memmove. + var da, sa unsafe.Pointer + if idst.kind == Array { + da = idst.addr + } else { + da = unsafe.Pointer((*SliceHeader)(idst.addr).Data) + } + if isrc.kind == Array { + sa = isrc.addr + } else { + sa = unsafe.Pointer((*SliceHeader)(isrc.addr).Data) + } + memmove(da, sa, uintptr(n)*de.Size()) + return n } -func (v *sliceValue) slice() *SliceHeader { return (*SliceHeader)(v.value.addr) } - -// addr returns the base address of the data in the slice. -func (v *sliceValue) addr() addr { return addr(v.slice().Data) } +/* + * constructors + */ // MakeSlice creates a new zero-initialized slice value // for the specified slice type, length, and capacity. @@ -1357,26 +1574,9 @@ func MakeSlice(typ Type, len, cap int) Value { Len: len, Cap: cap, } - return newValue(typ, addr(s), canAddr|canSet|canStore) -} - -/* - * chan - */ - -// A chanValue represents a chan. -type chanValue struct { - value "chan" + return valueFromAddr(0, typ, unsafe.Pointer(s)) } -// implemented in ../pkg/runtime/reflect.cgo -func makechan(typ *runtime.ChanType, size uint32) (ch *byte) -func chansend(ch, val *byte, selected *bool) -func chanrecv(ch, val *byte, selected *bool, ok *bool) -func chanclose(ch *byte) -func chanlen(ch *byte) int32 -func chancap(ch *byte) int32 - // MakeChan creates a new channel with the specified type and buffer size. func MakeChan(typ Type, buffer int) Value { if typ.Kind() != Chan { @@ -1388,121 +1588,17 @@ func MakeChan(typ Type, buffer int) Value { if typ.ChanDir() != BothDir { panic("MakeChan: unidirectional channel type") } - v := Zero(typ) - ch := v.panicIfNot(Chan).(*chanValue) - *(**byte)(ch.addr) = makechan((*runtime.ChanType)(unsafe.Pointer(typ.(*commonType))), uint32(buffer)) - return v + ch := makechan(typ.runtimeType(), uint32(buffer)) + return valueFromIword(0, typ, ch) } -/* - * func - */ - -// A funcValue represents a function value. -type funcValue struct { - value "func" - first *value - isInterface bool -} - -// Method returns a funcValue corresponding to v's i'th method. -// The arguments to a Call on the returned funcValue -// should not include a receiver; the funcValue will use v -// as the receiver. -func (v *value) Method(i int) Value { - t := v.Type().uncommon() - if t == nil || i < 0 || i >= len(t.methods) { - panic("reflect: Method index out of range") - } - p := &t.methods[i] - fn := p.tfn - fv := &funcValue{value: value{toType(p.typ), addr(&fn), 0}, first: v, isInterface: false} - return Value{fv} -} - -// implemented in ../pkg/runtime/*/asm.s -func call(fn, arg *byte, n uint32) - -// Interface returns the fv as an interface value. -// If fv is a method obtained by invoking Value.Method -// (as opposed to Type.Method), Interface cannot return an -// interface value, so it panics. -func (fv *funcValue) Interface() interface{} { - if fv.first != nil { - panic("funcValue: cannot create interface value for method with bound receiver") - } - return fv.value.Interface() -} - -/* - * interface - */ - -// An interfaceValue represents an interface value. -type interfaceValue struct { - value "interface" -} - -// ../runtime/reflect.cgo -func setiface(typ *interfaceType, x *interface{}, addr addr) - -// Method returns a funcValue corresponding to v's i'th method. -// The arguments to a Call on the returned funcValue -// should not include a receiver; the funcValue will use v -// as the receiver. -func (v *interfaceValue) Method(i int) Value { - t := (*interfaceType)(unsafe.Pointer(v.Type().(*commonType))) - if t == nil || i < 0 || i >= len(t.methods) { - panic("reflect: Method index out of range") - } - p := &t.methods[i] - - // Interface is two words: itable, data. - tab := *(**runtime.Itable)(v.addr) - data := &value{Typeof((*byte)(nil)), addr(uintptr(v.addr) + ptrSize), 0} - - // Function pointer is at p.perm in the table. - fn := tab.Fn[i] - fv := &funcValue{value: value{toType(p.typ), addr(&fn), 0}, first: data, isInterface: true} - return Value{fv} -} - -/* - * map - */ - -// A mapValue represents a map value. -type mapValue struct { - value "map" -} - -// implemented in ../pkg/runtime/reflect.cgo -func mapaccess(m, key, val *byte) bool -func mapassign(m, key, val *byte) -func maplen(m *byte) int32 -func mapiterinit(m *byte) *byte -func mapiternext(it *byte) -func mapiterkey(it *byte, key *byte) bool -func makemap(t *runtime.MapType) *byte - // MakeMap creates a new map of the specified type. func MakeMap(typ Type) Value { if typ.Kind() != Map { panic("reflect: MakeMap of non-map type") } - v := Zero(typ) - m := v.panicIfNot(Map).(*mapValue) - *(**byte)(m.addr) = makemap((*runtime.MapType)(unsafe.Pointer(typ.(*commonType)))) - return v -} - -/* - * ptr - */ - -// A ptrValue represents a pointer. -type ptrValue struct { - value "ptr" + m := makemap(typ.runtimeType()) + return valueFromIword(0, typ, m) } // Indirect returns the value that v points to. @@ -1515,73 +1611,90 @@ func Indirect(v Value) Value { return v.Elem() } -/* - * struct - */ - -// A structValue represents a struct value. -type structValue struct { - value "struct" -} - -/* - * constructors - */ - -// NewValue returns a new Value initialized to the concrete value -// stored in the interface i. NewValue(nil) returns the zero Value. -func NewValue(i interface{}) Value { +// ValueOf returns a new Value initialized to the concrete value +// stored in the interface i. ValueOf(nil) returns the zero Value. +func ValueOf(i interface{}) Value { if i == nil { return Value{} } - _, a := unsafe.Reflect(i) - return newValue(Typeof(i), addr(a), canSet|canAddr|canStore) -} - -func newValue(typ Type, addr addr, flag uint32) Value { - v := value{typ, addr, flag} - switch typ.Kind() { - case Array: - return Value{&arrayValue{v}} - case Bool: - return Value{&boolValue{v}} - case Chan: - return Value{&chanValue{v}} - case Float32, Float64: - return Value{&floatValue{v}} - case Func: - return Value{&funcValue{value: v}} - case Complex64, Complex128: - return Value{&complexValue{v}} - case Int, Int8, Int16, Int32, Int64: - return Value{&intValue{v}} - case Interface: - return Value{&interfaceValue{v}} - case Map: - return Value{&mapValue{v}} - case Ptr: - return Value{&ptrValue{v}} - case Slice: - return Value{&sliceValue{v}} - case String: - return Value{&stringValue{v}} - case Struct: - return Value{&structValue{v}} - case Uint, Uint8, Uint16, Uint32, Uint64, Uintptr: - return Value{&uintValue{v}} - case UnsafePointer: - return Value{&unsafePointerValue{v}} - } - panic("newValue" + typ.String()) + // For an interface value with the noAddr bit set, + // the representation is identical to an empty interface. + eface := *(*emptyInterface)(unsafe.Pointer(&i)) + return packValue(0, eface.typ, eface.word) } // Zero returns a Value representing a zero value for the specified type. // The result is different from the zero value of the Value struct, // which represents no value at all. -// For example, Zero(Typeof(42)) returns a Value with Kind Int and value 0. +// For example, Zero(TypeOf(42)) returns a Value with Kind Int and value 0. func Zero(typ Type) Value { if typ == nil { panic("reflect: Zero(nil)") } - return newValue(typ, addr(unsafe.New(typ)), canSet|canAddr|canStore) + if typ.Size() <= ptrSize { + return valueFromIword(0, typ, 0) + } + return valueFromAddr(0, typ, unsafe.New(typ)) } + +// New returns a Value representing a pointer to a new zero value +// for the specified type. That is, the returned Value's Type is PtrTo(t). +func New(typ Type) Value { + if typ == nil { + panic("reflect: New(nil)") + } + ptr := unsafe.New(typ) + return valueFromIword(0, PtrTo(typ), iword(ptr)) +} + +// convertForAssignment +func convertForAssignment(what string, addr unsafe.Pointer, dst Type, iv internalValue) internalValue { + if iv.method { + panic(what + ": cannot assign method value to type " + dst.String()) + } + + dst1 := dst.(*commonType) + if directlyAssignable(dst1, iv.typ) { + // Overwrite type so that they match. + // Same memory layout, so no harm done. + iv.typ = dst1 + return iv + } + if implements(dst1, iv.typ) { + if addr == nil { + addr = unsafe.Pointer(new(interface{})) + } + x := iv.Interface() + if dst.NumMethod() == 0 { + *(*interface{})(addr) = x + } else { + ifaceE2I(dst1.runtimeType(), x, addr) + } + iv.addr = addr + iv.word = iword(addr) + iv.typ = dst1 + return iv + } + + // Failed. + panic(what + ": value of type " + iv.typ.String() + " is not assignable to type " + dst.String()) +} + +// implemented in ../pkg/runtime +func chancap(ch iword) int32 +func chanclose(ch iword) +func chanlen(ch iword) int32 +func chanrecv(ch iword, nb bool) (val iword, selected, received bool) +func chansend(ch iword, val iword, nb bool) bool + +func makechan(typ *runtime.Type, size uint32) (ch iword) +func makemap(t *runtime.Type) iword +func mapaccess(m iword, key iword) (val iword, ok bool) +func mapassign(m iword, key, val iword, ok bool) +func mapiterinit(m iword) *byte +func mapiterkey(it *byte) (key iword, ok bool) +func mapiternext(it *byte) +func maplen(m iword) int32 + +func call(fn, arg unsafe.Pointer, n uint32) +func ifaceE2I(t *runtime.Type, src interface{}, dst unsafe.Pointer) diff --git a/src/pkg/rpc/server.go b/src/pkg/rpc/server.go index af31a65cc..acadeec37 100644 --- a/src/pkg/rpc/server.go +++ b/src/pkg/rpc/server.go @@ -3,7 +3,7 @@ // license that can be found in the LICENSE file. /* - The rpc package provides access to the exported methods of an object across a + Package rpc provides access to the exported methods of an object across a network or other I/O connection. A server registers an object, making it visible as a service with the name of the type of the object. After registration, exported methods of the object will be accessible remotely. A server may register multiple @@ -13,8 +13,11 @@ Only methods that satisfy these criteria will be made available for remote access; other methods will be ignored: - - the method receiver and name are exported, that is, begin with an upper case letter. - - the method has two arguments, both pointers to exported types. + - the method name is exported, that is, begins with an upper case letter. + - the method receiver is exported or local (defined in the package + registering the service). + - the method has two arguments, both exported or local types. + - the method's second argument is a pointer. - the method has return type os.Error. The method's first argument represents the arguments provided by the caller; the @@ -133,7 +136,7 @@ const ( // Precompute the reflect type for os.Error. Can't use os.Error directly // because Typeof takes an empty interface value. This is annoying. var unusedError *os.Error -var typeOfOsError = reflect.Typeof(unusedError).Elem() +var typeOfOsError = reflect.TypeOf(unusedError).Elem() type methodType struct { sync.Mutex // protects counters @@ -193,6 +196,14 @@ func isExported(name string) bool { return unicode.IsUpper(rune) } +// Is this type exported or local to this package? +func isExportedOrLocalType(t reflect.Type) bool { + for t.Kind() == reflect.Ptr { + t = t.Elem() + } + return t.PkgPath() == "" || isExported(t.Name()) +} + // Register publishes in the server the set of methods of the // receiver value that satisfy the following conditions: // - exported method @@ -219,8 +230,8 @@ func (server *Server) register(rcvr interface{}, name string, useName bool) os.E server.serviceMap = make(map[string]*service) } s := new(service) - s.typ = reflect.Typeof(rcvr) - s.rcvr = reflect.NewValue(rcvr) + s.typ = reflect.TypeOf(rcvr) + s.rcvr = reflect.ValueOf(rcvr) sname := reflect.Indirect(s.rcvr).Type().Name() if useName { sname = name @@ -252,23 +263,20 @@ func (server *Server) register(rcvr interface{}, name string, useName bool) os.E log.Println("method", mname, "has wrong number of ins:", mtype.NumIn()) continue } + // First arg need not be a pointer. argType := mtype.In(1) - ok := argType.Kind() == reflect.Ptr - if !ok { - log.Println(mname, "arg type not a pointer:", mtype.In(1)) + if !isExportedOrLocalType(argType) { + log.Println(mname, "argument type not exported or local:", argType) continue } + // Second arg must be a pointer. replyType := mtype.In(2) if replyType.Kind() != reflect.Ptr { - log.Println(mname, "reply type not a pointer:", mtype.In(2)) - continue - } - if argType.Elem().PkgPath() != "" && !isExported(argType.Elem().Name()) { - log.Println(mname, "argument type not exported:", argType) + log.Println("method", mname, "reply type not a pointer:", replyType) continue } - if replyType.Elem().PkgPath() != "" && !isExported(replyType.Elem().Name()) { - log.Println(mname, "reply type not exported:", replyType) + if !isExportedOrLocalType(replyType) { + log.Println("method", mname, "reply type not exported or local:", replyType) continue } // Method needs one out: os.Error. @@ -297,12 +305,6 @@ type InvalidRequest struct{} var invalidRequest = InvalidRequest{} -func _new(t reflect.Type) reflect.Value { - v := reflect.Zero(t) - v.Set(reflect.Zero(t.Elem()).Addr()) - return v -} - func (server *Server) sendResponse(sending *sync.Mutex, req *Request, reply interface{}, codec ServerCodec, errmsg string) { resp := server.getResponse() // Encode the response header @@ -411,8 +413,16 @@ func (server *Server) ServeCodec(codec ServerCodec) { } // Decode the argument value. - argv := _new(mtype.ArgType) - replyv := _new(mtype.ReplyType) + var argv reflect.Value + argIsValue := false // if true, need to indirect before calling. + if mtype.ArgType.Kind() == reflect.Ptr { + argv = reflect.New(mtype.ArgType.Elem()) + } else { + argv = reflect.New(mtype.ArgType) + argIsValue = true + } + // argv guaranteed to be a pointer now. + replyv := reflect.New(mtype.ReplyType.Elem()) err = codec.ReadRequestBody(argv.Interface()) if err != nil { if err == os.EOF || err == io.ErrUnexpectedEOF { @@ -424,6 +434,9 @@ func (server *Server) ServeCodec(codec ServerCodec) { server.sendResponse(sending, req, replyv.Interface(), codec, err.String()) continue } + if argIsValue { + argv = argv.Elem() + } go service.call(server, sending, mtype, req, argv, replyv, codec) } codec.Close() diff --git a/src/pkg/rpc/server_test.go b/src/pkg/rpc/server_test.go index d4041ae70..cfff0c9ad 100644 --- a/src/pkg/rpc/server_test.go +++ b/src/pkg/rpc/server_test.go @@ -38,7 +38,9 @@ type Reply struct { type Arith int -func (t *Arith) Add(args *Args, reply *Reply) os.Error { +// Some of Arith's methods have value args, some have pointer args. That's deliberate. + +func (t *Arith) Add(args Args, reply *Reply) os.Error { reply.C = args.A + args.B return nil } @@ -48,7 +50,7 @@ func (t *Arith) Mul(args *Args, reply *Reply) os.Error { return nil } -func (t *Arith) Div(args *Args, reply *Reply) os.Error { +func (t *Arith) Div(args Args, reply *Reply) os.Error { if args.B == 0 { return os.ErrorString("divide by zero") } @@ -61,8 +63,8 @@ func (t *Arith) String(args *Args, reply *string) os.Error { return nil } -func (t *Arith) Scan(args *string, reply *Reply) (err os.Error) { - _, err = fmt.Sscan(*args, &reply.C) +func (t *Arith) Scan(args string, reply *Reply) (err os.Error) { + _, err = fmt.Sscan(args, &reply.C) return } @@ -262,16 +264,11 @@ func testHTTPRPC(t *testing.T, path string) { } } -type ArgNotPointer int type ReplyNotPointer int type ArgNotPublic int type ReplyNotPublic int type local struct{} -func (t *ArgNotPointer) ArgNotPointer(args Args, reply *Reply) os.Error { - return nil -} - func (t *ReplyNotPointer) ReplyNotPointer(args *Args, reply Reply) os.Error { return nil } @@ -286,11 +283,7 @@ func (t *ReplyNotPublic) ReplyNotPublic(args *Args, reply *local) os.Error { // Check that registration handles lots of bad methods and a type with no suitable methods. func TestRegistrationError(t *testing.T) { - err := Register(new(ArgNotPointer)) - if err == nil { - t.Errorf("expected error registering ArgNotPointer") - } - err = Register(new(ReplyNotPointer)) + err := Register(new(ReplyNotPointer)) if err == nil { t.Errorf("expected error registering ReplyNotPointer") } @@ -351,18 +344,26 @@ func testSendDeadlock(client *Client) { client.Call("Arith.Add", args, reply) } -func TestCountMallocs(t *testing.T) { +func dialDirect() (*Client, os.Error) { + return Dial("tcp", serverAddr) +} + +func dialHTTP() (*Client, os.Error) { + return DialHTTP("tcp", httpServerAddr) +} + +func countMallocs(dial func() (*Client, os.Error), t *testing.T) uint64 { once.Do(startServer) - client, err := Dial("tcp", serverAddr) + client, err := dial() if err != nil { - t.Error("error dialing", err) + t.Fatal("error dialing", err) } args := &Args{7, 8} reply := new(Reply) mallocs := 0 - runtime.MemStats.Mallocs const count = 100 for i := 0; i < count; i++ { - err = client.Call("Arith.Add", args, reply) + err := client.Call("Arith.Add", args, reply) if err != nil { t.Errorf("Add: expected no error but got string %q", err.String()) } @@ -371,13 +372,21 @@ func TestCountMallocs(t *testing.T) { } } mallocs += runtime.MemStats.Mallocs - fmt.Printf("mallocs per rpc round trip: %d\n", mallocs/count) + return mallocs / count } -func BenchmarkEndToEnd(b *testing.B) { +func TestCountMallocs(t *testing.T) { + fmt.Printf("mallocs per rpc round trip: %d\n", countMallocs(dialDirect, t)) +} + +func TestCountMallocsOverHTTP(t *testing.T) { + fmt.Printf("mallocs per HTTP rpc round trip: %d\n", countMallocs(dialHTTP, t)) +} + +func benchmarkEndToEnd(dial func() (*Client, os.Error), b *testing.B) { b.StopTimer() once.Do(startServer) - client, err := Dial("tcp", serverAddr) + client, err := dial() if err != nil { fmt.Println("error dialing", err) return @@ -399,3 +408,11 @@ func BenchmarkEndToEnd(b *testing.B) { } } } + +func BenchmarkEndToEnd(b *testing.B) { + benchmarkEndToEnd(dialDirect, b) +} + +func BenchmarkEndToEndHTTP(b *testing.B) { + benchmarkEndToEnd(dialHTTP, b) +} diff --git a/src/pkg/runtime/386/asm.s b/src/pkg/runtime/386/asm.s index 598fc6846..e2cabef14 100644 --- a/src/pkg/runtime/386/asm.s +++ b/src/pkg/runtime/386/asm.s @@ -149,7 +149,7 @@ TEXT runtime·gogocall(SB), 7, $0 // void mcall(void (*fn)(G*)) // Switch to m->g0's stack, call fn(g). -// Fn must never return. It should gogo(&g->gobuf) +// Fn must never return. It should gogo(&g->sched) // to keep running g. TEXT runtime·mcall(SB), 7, $0 MOVL fn+0(FP), DI diff --git a/src/pkg/runtime/Makefile b/src/pkg/runtime/Makefile index 4da78c5f0..b122e0599 100644 --- a/src/pkg/runtime/Makefile +++ b/src/pkg/runtime/Makefile @@ -71,7 +71,6 @@ OFILES=\ msize.$O\ print.$O\ proc.$O\ - reflect.$O\ rune.$O\ runtime.$O\ runtime1.$O\ diff --git a/src/pkg/runtime/amd64/asm.s b/src/pkg/runtime/amd64/asm.s index a611985c5..46d82e365 100644 --- a/src/pkg/runtime/amd64/asm.s +++ b/src/pkg/runtime/amd64/asm.s @@ -133,7 +133,7 @@ TEXT runtime·gogocall(SB), 7, $0 // void mcall(void (*fn)(G*)) // Switch to m->g0's stack, call fn(g). -// Fn must never return. It should gogo(&g->gobuf) +// Fn must never return. It should gogo(&g->sched) // to keep running g. TEXT runtime·mcall(SB), 7, $0 MOVQ fn+0(FP), DI diff --git a/src/pkg/runtime/arm/asm.s b/src/pkg/runtime/arm/asm.s index 4d36606a7..63153658f 100644 --- a/src/pkg/runtime/arm/asm.s +++ b/src/pkg/runtime/arm/asm.s @@ -128,7 +128,7 @@ TEXT runtime·gogocall(SB), 7, $-4 // void mcall(void (*fn)(G*)) // Switch to m->g0's stack, call fn(g). -// Fn must never return. It should gogo(&g->gobuf) +// Fn must never return. It should gogo(&g->sched) // to keep running g. TEXT runtime·mcall(SB), 7, $-4 MOVW fn+0(FP), R0 diff --git a/src/pkg/runtime/arm/softfloat.c b/src/pkg/runtime/arm/softfloat.c index f60fab14f..f91a6fc09 100644 --- a/src/pkg/runtime/arm/softfloat.c +++ b/src/pkg/runtime/arm/softfloat.c @@ -91,6 +91,7 @@ static uint32 stepflt(uint32 *pc, uint32 *regs) { uint32 i, regd, regm, regn; + int32 delta; uint32 *addr; uint64 uval; int64 sval; @@ -117,7 +118,7 @@ stepflt(uint32 *pc, uint32 *regs) return 1; } if(i == 0xe08bb00d) { - // add sp to 11. + // add sp to r11. // might be part of a large stack offset address // (or might not, but again no harm done). regs[11] += regs[13]; @@ -134,6 +135,19 @@ stepflt(uint32 *pc, uint32 *regs) runtime·printf("*** fpsr R[CPSR] = F[CPSR] %x\n", regs[CPSR]); return 1; } + if((i&0xff000000) == 0xea000000) { + // unconditional branch + // can happen in the middle of floating point + // if the linker decides it is time to lay down + // a sequence of instruction stream constants. + delta = i&0xffffff; + delta = (delta<<8) >> 8; // sign extend + + if(trace) + runtime·printf("*** cpu PC += %x\n", (delta+2)*4); + return delta+2; + } + goto stage1; stage1: // load/store regn is cpureg, regm is 8bit offset @@ -489,8 +503,10 @@ runtime·_sfloat2(uint32 *lr, uint32 r0) uint32 skip; skip = stepflt(lr, &r0); - if(skip == 0) + if(skip == 0) { + runtime·printf("sfloat2 %p %x\n", lr, *lr); fabort(); // not ok to fail first instruction + } lr += skip; while(skip = stepflt(lr, &r0)) diff --git a/src/pkg/runtime/chan.c b/src/pkg/runtime/chan.c index 8c45b076d..f94c3ef40 100644 --- a/src/pkg/runtime/chan.c +++ b/src/pkg/runtime/chan.c @@ -9,7 +9,6 @@ static int32 debug = 0; -typedef struct Link Link; typedef struct WaitQ WaitQ; typedef struct SudoG SudoG; typedef struct Select Select; @@ -51,12 +50,6 @@ struct Hchan // chanbuf(c, i) is pointer to the i'th slot in the buffer. #define chanbuf(c, i) ((byte*)((c)+1)+(uintptr)(c)->elemsize*(i)) -struct Link -{ - Link* link; // asynch queue circular linked list - byte elem[8]; // asynch queue data element (+ more) -}; - enum { // Scase.kind @@ -121,7 +114,6 @@ runtime·makechan_c(Type *elem, int64 hint) by = runtime·mal(n + hint*elem->size); c = (Hchan*)by; - by += n; runtime·addfinalizer(c, destroychan, 0); c->elemsize = elem->size; @@ -136,6 +128,15 @@ runtime·makechan_c(Type *elem, int64 hint) return c; } +// For reflect +// func makechan(typ *ChanType, size uint32) (chan) +void +reflect·makechan(ChanType *t, uint32 size, Hchan *c) +{ + c = runtime·makechan_c(t->elem, size); + FLUSH(&c); +} + static void destroychan(Hchan *c) { @@ -271,6 +272,7 @@ closed: runtime·panicstring("send on closed channel"); } + void runtime·chanrecv(Hchan* c, byte *ep, bool *selected, bool *received) { @@ -527,6 +529,71 @@ runtime·selectnbrecv2(byte *v, bool *received, Hchan *c, bool selected) runtime·chanrecv(c, v, &selected, received); } +// For reflect: +// func chansend(c chan, val iword, nb bool) (selected bool) +// where an iword is the same word an interface value would use: +// the actual data if it fits, or else a pointer to the data. +// +// The "uintptr selected" is really "bool selected" but saying +// uintptr gets us the right alignment for the output parameter block. +void +reflect·chansend(Hchan *c, uintptr val, bool nb, uintptr selected) +{ + bool *sp; + byte *vp; + + if(c == nil) + runtime·panicstring("send to nil channel"); + + if(nb) { + selected = false; + sp = (bool*)&selected; + } else { + *(bool*)&selected = true; + FLUSH(&selected); + sp = nil; + } + if(c->elemsize <= sizeof(val)) + vp = (byte*)&val; + else + vp = (byte*)val; + runtime·chansend(c, vp, sp); +} + +// For reflect: +// func chanrecv(c chan, nb bool) (val iword, selected, received bool) +// where an iword is the same word an interface value would use: +// the actual data if it fits, or else a pointer to the data. +void +reflect·chanrecv(Hchan *c, bool nb, uintptr val, bool selected, bool received) +{ + byte *vp; + bool *sp; + + if(c == nil) + runtime·panicstring("receive from nil channel"); + + if(nb) { + selected = false; + sp = &selected; + } else { + selected = true; + FLUSH(&selected); + sp = nil; + } + received = false; + FLUSH(&received); + if(c->elemsize <= sizeof(val)) { + val = 0; + vp = (byte*)&val; + } else { + vp = runtime·mal(c->elemsize); + val = (uintptr)vp; + FLUSH(&val); + } + runtime·chanrecv(c, vp, sp, &received); +} + static void newselect(int32, Select**); // newselect(size uint32) (sel *byte); @@ -1044,22 +1111,36 @@ runtime·closechan(Hchan *c) runtime·unlock(c); } +// For reflect +// func chanclose(c chan) void -runtime·chanclose(Hchan *c) +reflect·chanclose(Hchan *c) { runtime·closechan(c); } -int32 -runtime·chanlen(Hchan *c) +// For reflect +// func chanlen(c chan) (len int32) +void +reflect·chanlen(Hchan *c, int32 len) { - return c->qcount; + if(c == nil) + len = 0; + else + len = c->qcount; + FLUSH(&len); } -int32 -runtime·chancap(Hchan *c) +// For reflect +// func chancap(c chan) (cap int32) +void +reflect·chancap(Hchan *c, int32 cap) { - return c->dataqsiz; + if(c == nil) + cap = 0; + else + cap = c->dataqsiz; + FLUSH(&cap); } static SudoG* diff --git a/src/pkg/runtime/darwin/386/signal.c b/src/pkg/runtime/darwin/386/signal.c index 35bbb178b..29170b669 100644 --- a/src/pkg/runtime/darwin/386/signal.c +++ b/src/pkg/runtime/darwin/386/signal.c @@ -185,3 +185,10 @@ runtime·resetcpuprofiler(int32 hz) } m->profilehz = hz; } + +void +os·sigpipe(void) +{ + sigaction(SIGPIPE, SIG_DFL, false); + runtime·raisesigpipe(); +} diff --git a/src/pkg/runtime/darwin/386/sys.s b/src/pkg/runtime/darwin/386/sys.s index 08eca9d5a..87fbdbb79 100644 --- a/src/pkg/runtime/darwin/386/sys.s +++ b/src/pkg/runtime/darwin/386/sys.s @@ -33,6 +33,16 @@ TEXT runtime·write(SB),7,$0 INT $0x80 RET +TEXT runtime·raisesigpipe(SB),7,$8 + get_tls(CX) + MOVL m(CX), DX + MOVL m_procid(DX), DX + MOVL DX, 0(SP) // thread_port + MOVL $13, 4(SP) // signal: SIGPIPE + MOVL $328, AX // __pthread_kill + INT $0x80 + RET + TEXT runtime·mmap(SB),7,$0 MOVL $197, AX INT $0x80 diff --git a/src/pkg/runtime/darwin/amd64/signal.c b/src/pkg/runtime/darwin/amd64/signal.c index 3a99d2308..036a3aca7 100644 --- a/src/pkg/runtime/darwin/amd64/signal.c +++ b/src/pkg/runtime/darwin/amd64/signal.c @@ -195,3 +195,10 @@ runtime·resetcpuprofiler(int32 hz) } m->profilehz = hz; } + +void +os·sigpipe(void) +{ + sigaction(SIGPIPE, SIG_DFL, false); + runtime·raisesigpipe(); +} diff --git a/src/pkg/runtime/darwin/amd64/sys.s b/src/pkg/runtime/darwin/amd64/sys.s index 39398e065..8d1b20f11 100644 --- a/src/pkg/runtime/darwin/amd64/sys.s +++ b/src/pkg/runtime/darwin/amd64/sys.s @@ -38,6 +38,15 @@ TEXT runtime·write(SB),7,$0 SYSCALL RET +TEXT runtime·raisesigpipe(SB),7,$24 + get_tls(CX) + MOVQ m(CX), DX + MOVL $13, DI // arg 1 SIGPIPE + MOVQ m_procid(DX), SI // arg 2 thread_port + MOVL $(0x2000000+328), AX // syscall entry __pthread_kill + SYSCALL + RET + TEXT runtime·setitimer(SB), 7, $0 MOVL 8(SP), DI MOVQ 16(SP), SI diff --git a/src/pkg/runtime/darwin/mem.c b/src/pkg/runtime/darwin/mem.c index cbae18718..935c032bc 100644 --- a/src/pkg/runtime/darwin/mem.c +++ b/src/pkg/runtime/darwin/mem.c @@ -36,6 +36,11 @@ runtime·SysReserve(void *v, uintptr n) return runtime·mmap(v, n, PROT_NONE, MAP_ANON|MAP_PRIVATE, -1, 0); } +enum +{ + ENOMEM = 12, +}; + void runtime·SysMap(void *v, uintptr n) { @@ -43,6 +48,8 @@ runtime·SysMap(void *v, uintptr n) mstats.sys += n; p = runtime·mmap(v, n, PROT_READ|PROT_WRITE|PROT_EXEC, MAP_ANON|MAP_FIXED|MAP_PRIVATE, -1, 0); + if(p == (void*)-ENOMEM) + runtime·throw("runtime: out of memory"); if(p != v) runtime·throw("runtime: cannot map pages in arena address space"); } diff --git a/src/pkg/runtime/darwin/os.h b/src/pkg/runtime/darwin/os.h index 339768e51..db3c2e8a7 100644 --- a/src/pkg/runtime/darwin/os.h +++ b/src/pkg/runtime/darwin/os.h @@ -27,3 +27,5 @@ void runtime·sigaltstack(struct StackT*, struct StackT*); void runtime·sigtramp(void); void runtime·sigpanic(void); void runtime·setitimer(int32, Itimerval*, Itimerval*); + +void runtime·raisesigpipe(void); diff --git a/src/pkg/runtime/debug/stack.go b/src/pkg/runtime/debug/stack.go index e7d56ac23..e5fae632b 100644 --- a/src/pkg/runtime/debug/stack.go +++ b/src/pkg/runtime/debug/stack.go @@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The debug package contains facilities for programs to debug themselves -// while they are running. +// Package debug contains facilities for programs to debug themselves while +// they are running. package debug import ( diff --git a/src/pkg/runtime/extern.go b/src/pkg/runtime/extern.go index c6e664abb..9da3423c6 100644 --- a/src/pkg/runtime/extern.go +++ b/src/pkg/runtime/extern.go @@ -3,7 +3,7 @@ // license that can be found in the LICENSE file. /* - The runtime package contains operations that interact with Go's runtime system, + Package runtime contains operations that interact with Go's runtime system, such as functions to control goroutines. It also includes the low-level type information used by the reflect package; see reflect's documentation for the programmable interface to the run-time type system. diff --git a/src/pkg/runtime/freebsd/386/signal.c b/src/pkg/runtime/freebsd/386/signal.c index 1ae2554eb..3600f0762 100644 --- a/src/pkg/runtime/freebsd/386/signal.c +++ b/src/pkg/runtime/freebsd/386/signal.c @@ -182,3 +182,10 @@ runtime·resetcpuprofiler(int32 hz) } m->profilehz = hz; } + +void +os·sigpipe(void) +{ + sigaction(SIGPIPE, SIG_DFL, false); + runtime·raisesigpipe(); +} diff --git a/src/pkg/runtime/freebsd/386/sys.s b/src/pkg/runtime/freebsd/386/sys.s index c4715b668..765e2fcc4 100644 --- a/src/pkg/runtime/freebsd/386/sys.s +++ b/src/pkg/runtime/freebsd/386/sys.s @@ -60,6 +60,20 @@ TEXT runtime·write(SB),7,$-4 INT $0x80 RET +TEXT runtime·raisesigpipe(SB),7,$12 + // thr_self(&8(SP)) + LEAL 8(SP), AX + MOVL AX, 0(SP) + MOVL $432, AX + INT $0x80 + // thr_kill(self, SIGPIPE) + MOVL 8(SP), AX + MOVL AX, 0(SP) + MOVL $13, 4(SP) + MOVL $433, AX + INT $0x80 + RET + TEXT runtime·notok(SB),7,$0 MOVL $0xf1, 0xf1 RET diff --git a/src/pkg/runtime/freebsd/amd64/signal.c b/src/pkg/runtime/freebsd/amd64/signal.c index 9d8e5e692..85cb1d855 100644 --- a/src/pkg/runtime/freebsd/amd64/signal.c +++ b/src/pkg/runtime/freebsd/amd64/signal.c @@ -190,3 +190,10 @@ runtime·resetcpuprofiler(int32 hz) } m->profilehz = hz; } + +void +os·sigpipe(void) +{ + sigaction(SIGPIPE, SIG_DFL, false); + runtime·raisesigpipe(); +} diff --git a/src/pkg/runtime/freebsd/amd64/sys.s b/src/pkg/runtime/freebsd/amd64/sys.s index 9a6fdf1ac..c5cc082e4 100644 --- a/src/pkg/runtime/freebsd/amd64/sys.s +++ b/src/pkg/runtime/freebsd/amd64/sys.s @@ -65,6 +65,18 @@ TEXT runtime·write(SB),7,$-8 SYSCALL RET +TEXT runtime·raisesigpipe(SB),7,$16 + // thr_self(&8(SP)) + LEAQ 8(SP), DI // arg 1 &8(SP) + MOVL $432, AX + SYSCALL + // thr_kill(self, SIGPIPE) + MOVQ 8(SP), DI // arg 1 id + MOVQ $13, SI // arg 2 SIGPIPE + MOVL $433, AX + SYSCALL + RET + TEXT runtime·setitimer(SB), 7, $-8 MOVL 8(SP), DI MOVQ 16(SP), SI diff --git a/src/pkg/runtime/freebsd/mem.c b/src/pkg/runtime/freebsd/mem.c index f80439e38..07abf2cfe 100644 --- a/src/pkg/runtime/freebsd/mem.c +++ b/src/pkg/runtime/freebsd/mem.c @@ -42,6 +42,11 @@ runtime·SysReserve(void *v, uintptr n) return runtime·mmap(v, n, PROT_NONE, MAP_ANON|MAP_PRIVATE, -1, 0); } +enum +{ + ENOMEM = 12, +}; + void runtime·SysMap(void *v, uintptr n) { @@ -52,6 +57,8 @@ runtime·SysMap(void *v, uintptr n) // On 64-bit, we don't actually have v reserved, so tread carefully. if(sizeof(void*) == 8) { p = runtime·mmap(v, n, PROT_READ|PROT_WRITE|PROT_EXEC, MAP_ANON|MAP_PRIVATE, -1, 0); + if(p == (void*)-ENOMEM) + runtime·throw("runtime: out of memory"); if(p != v) { runtime·printf("runtime: address space conflict: map(%p) = %p\n", v, p); runtime·throw("runtime: address space conflict"); @@ -60,6 +67,8 @@ runtime·SysMap(void *v, uintptr n) } p = runtime·mmap(v, n, PROT_READ|PROT_WRITE|PROT_EXEC, MAP_ANON|MAP_FIXED|MAP_PRIVATE, -1, 0); + if(p == (void*)-ENOMEM) + runtime·throw("runtime: out of memory"); if(p != v) runtime·throw("runtime: cannot map pages in arena address space"); } diff --git a/src/pkg/runtime/freebsd/os.h b/src/pkg/runtime/freebsd/os.h index 13754688b..007856c6b 100644 --- a/src/pkg/runtime/freebsd/os.h +++ b/src/pkg/runtime/freebsd/os.h @@ -8,3 +8,5 @@ struct sigaction; void runtime·sigaction(int32, struct sigaction*, struct sigaction*); void runtiem·setitimerval(int32, Itimerval*, Itimerval*); void runtime·setitimer(int32, Itimerval*, Itimerval*); + +void runtime·raisesigpipe(void); diff --git a/src/pkg/runtime/hashmap.c b/src/pkg/runtime/hashmap.c index e50cefd9a..5ba1eb20a 100644 --- a/src/pkg/runtime/hashmap.c +++ b/src/pkg/runtime/hashmap.c @@ -776,6 +776,15 @@ runtime·makemap(Type *key, Type *val, int64 hint, Hmap *ret) FLUSH(&ret); } +// For reflect: +// func makemap(Type *mapType) (hmap *map) +void +reflect·makemap(MapType *t, Hmap *ret) +{ + ret = runtime·makemap_c(t->key, t->elem, 0); + FLUSH(&ret); +} + void runtime·mapaccess(Hmap *h, byte *ak, byte *av, bool *pres) { @@ -855,6 +864,34 @@ runtime·mapaccess2(Hmap *h, ...) } } +// For reflect: +// func mapaccess(h map, key iword) (val iword, pres bool) +// where an iword is the same word an interface value would use: +// the actual data if it fits, or else a pointer to the data. +void +reflect·mapaccess(Hmap *h, uintptr key, uintptr val, bool pres) +{ + byte *ak, *av; + + if(h == nil) + runtime·panicstring("lookup in nil map"); + if(h->keysize <= sizeof(key)) + ak = (byte*)&key; + else + ak = (byte*)key; + val = 0; + pres = false; + if(h->valsize <= sizeof(val)) + av = (byte*)&val; + else { + av = runtime·mal(h->valsize); + val = (uintptr)av; + } + runtime·mapaccess(h, ak, av, &pres); + FLUSH(&val); + FLUSH(&pres); +} + void runtime·mapassign(Hmap *h, byte *ak, byte *av) { @@ -938,6 +975,30 @@ runtime·mapassign2(Hmap *h, ...) } } +// For reflect: +// func mapassign(h map, key, val iword, pres bool) +// where an iword is the same word an interface value would use: +// the actual data if it fits, or else a pointer to the data. +void +reflect·mapassign(Hmap *h, uintptr key, uintptr val, bool pres) +{ + byte *ak, *av; + + if(h == nil) + runtime·panicstring("lookup in nil map"); + if(h->keysize <= sizeof(key)) + ak = (byte*)&key; + else + ak = (byte*)key; + if(h->valsize <= sizeof(val)) + av = (byte*)&val; + else + av = (byte*)val; + if(!pres) + av = nil; + runtime·mapassign(h, ak, av); +} + // mapiterinit(hmap *map[any]any, hiter *any); void runtime·mapiterinit(Hmap *h, struct hash_iter *it) @@ -959,14 +1020,14 @@ runtime·mapiterinit(Hmap *h, struct hash_iter *it) } } -struct hash_iter* -runtime·newmapiterinit(Hmap *h) +// For reflect: +// func mapiterinit(h map) (it iter) +void +reflect·mapiterinit(Hmap *h, struct hash_iter *it) { - struct hash_iter *it; - it = runtime·mal(sizeof *it); + FLUSH(&it); runtime·mapiterinit(h, it); - return it; } // mapiternext(hiter *any); @@ -986,6 +1047,14 @@ runtime·mapiternext(struct hash_iter *it) } } +// For reflect: +// func mapiternext(it iter) +void +reflect·mapiternext(struct hash_iter *it) +{ + runtime·mapiternext(it); +} + // mapiter1(hiter *any) (key any); #pragma textflag 7 void @@ -1026,6 +1095,48 @@ runtime·mapiterkey(struct hash_iter *it, void *ak) return true; } +// For reflect: +// func mapiterkey(h map) (key iword, ok bool) +// where an iword is the same word an interface value would use: +// the actual data if it fits, or else a pointer to the data. +void +reflect·mapiterkey(struct hash_iter *it, uintptr key, bool ok) +{ + Hmap *h; + byte *res; + + key = 0; + ok = false; + h = it->h; + res = it->data; + if(res == nil) { + key = 0; + ok = false; + } else { + key = 0; + if(h->keysize <= sizeof(key)) + h->keyalg->copy(h->keysize, (byte*)&key, res); + else + key = (uintptr)res; + ok = true; + } + FLUSH(&key); + FLUSH(&ok); +} + +// For reflect: +// func maplen(h map) (len int32) +// Like len(m) in the actual language, we treat the nil map as length 0. +void +reflect·maplen(Hmap *h, int32 len) +{ + if(h == nil) + len = 0; + else + len = h->count; + FLUSH(&len); +} + // mapiter2(hiter *any) (key any, val any); #pragma textflag 7 void diff --git a/src/pkg/runtime/iface.c b/src/pkg/runtime/iface.c index 698aead3d..b1015f695 100644 --- a/src/pkg/runtime/iface.c +++ b/src/pkg/runtime/iface.c @@ -6,6 +6,14 @@ #include "type.h" #include "malloc.h" +enum +{ + // If an empty interface has these bits set in its type + // pointer, it was copied from a reflect.Value and is + // not a valid empty interface. + reflectFlags = 3, +}; + void runtime·printiface(Iface i) { @@ -42,7 +50,7 @@ itab(InterfaceType *inter, Type *type, int32 canfail) Method *t, *et; IMethod *i, *ei; uint32 h; - String *iname; + String *iname, *ipkgPath; Itab *m; UncommonType *x; Type *itype; @@ -112,6 +120,7 @@ search: for(; i < ei; i++) { itype = i->type; iname = i->name; + ipkgPath = i->pkgPath; for(;; t++) { if(t >= et) { if(!canfail) { @@ -128,7 +137,7 @@ search: m->bad = 1; goto out; } - if(t->mtyp == itype && t->name == iname) + if(t->mtyp == itype && t->name == iname && t->pkgPath == ipkgPath) break; } if(m) @@ -276,6 +285,8 @@ runtime·assertE2T(Type *t, Eface e, ...) { byte *ret; + if(((uintptr)e.type&reflectFlags) != 0) + runtime·throw("invalid interface value"); ret = (byte*)(&e+1); assertE2Tret(t, e, ret); } @@ -285,6 +296,8 @@ assertE2Tret(Type *t, Eface e, byte *ret) { Eface err; + if(((uintptr)e.type&reflectFlags) != 0) + runtime·throw("invalid interface value"); if(e.type == nil) { runtime·newTypeAssertionError(nil, nil, t, nil, nil, t->string, @@ -309,6 +322,8 @@ runtime·assertE2T2(Type *t, Eface e, ...) bool *ok; int32 wid; + if(((uintptr)e.type&reflectFlags) != 0) + runtime·throw("invalid interface value"); ret = (byte*)(&e+1); wid = t->size; ok = (bool*)(ret+runtime·rnd(wid, 1)); @@ -444,6 +459,8 @@ runtime·ifaceE2I(InterfaceType *inter, Eface e, Iface *ret) Type *t; Eface err; + if(((uintptr)e.type&reflectFlags) != 0) + runtime·throw("invalid interface value"); t = e.type; if(t == nil) { // explicit conversions require non-nil interface value. @@ -456,6 +473,14 @@ runtime·ifaceE2I(InterfaceType *inter, Eface e, Iface *ret) ret->tab = itab(inter, t, 0); } +// For reflect +// func ifaceE2I(t *InterfaceType, e interface{}, dst *Iface) +void +reflect·ifaceE2I(InterfaceType *inter, Eface e, Iface *dst) +{ + runtime·ifaceE2I(inter, e, dst); +} + // func ifaceE2I(sigi *byte, iface any) (ret any) void runtime·assertE2I(InterfaceType* inter, Eface e, Iface ret) @@ -467,6 +492,8 @@ runtime·assertE2I(InterfaceType* inter, Eface e, Iface ret) void runtime·assertE2I2(InterfaceType *inter, Eface e, Iface ret, bool ok) { + if(((uintptr)e.type&reflectFlags) != 0) + runtime·throw("invalid interface value"); if(e.type == nil) { ok = 0; ret.data = nil; @@ -489,6 +516,8 @@ runtime·assertE2E(InterfaceType* inter, Eface e, Eface ret) Type *t; Eface err; + if(((uintptr)e.type&reflectFlags) != 0) + runtime·throw("invalid interface value"); t = e.type; if(t == nil) { // explicit conversions require non-nil interface value. @@ -505,6 +534,8 @@ runtime·assertE2E(InterfaceType* inter, Eface e, Eface ret) void runtime·assertE2E2(InterfaceType* inter, Eface e, Eface ret, bool ok) { + if(((uintptr)e.type&reflectFlags) != 0) + runtime·throw("invalid interface value"); USED(inter); ret = e; ok = e.type != nil; @@ -582,6 +613,10 @@ runtime·ifaceeq_c(Iface i1, Iface i2) bool runtime·efaceeq_c(Eface e1, Eface e2) { + if(((uintptr)e1.type&reflectFlags) != 0) + runtime·throw("invalid interface value"); + if(((uintptr)e2.type&reflectFlags) != 0) + runtime·throw("invalid interface value"); if(e1.type != e2.type) return false; if(e1.type == nil) @@ -624,6 +659,8 @@ runtime·efacethash(Eface e1, uint32 ret) { Type *t; + if(((uintptr)e1.type&reflectFlags) != 0) + runtime·throw("invalid interface value"); ret = 0; t = e1.type; if(t != nil) @@ -634,11 +671,14 @@ runtime·efacethash(Eface e1, uint32 ret) void unsafe·Typeof(Eface e, Eface ret) { + if(((uintptr)e.type&reflectFlags) != 0) + runtime·throw("invalid interface value"); if(e.type == nil) { ret.type = nil; ret.data = nil; - } else - ret = *(Eface*)e.type; + } else { + ret = *(Eface*)(e.type); + } FLUSH(&ret); } @@ -648,6 +688,8 @@ unsafe·Reflect(Eface e, Eface rettype, void *retaddr) uintptr *p; uintptr x; + if(((uintptr)e.type&reflectFlags) != 0) + runtime·throw("invalid interface value"); if(e.type == nil) { rettype.type = nil; rettype.data = nil; @@ -678,6 +720,9 @@ unsafe·Reflect(Eface e, Eface rettype, void *retaddr) void unsafe·Unreflect(Eface typ, void *addr, Eface e) { + if(((uintptr)typ.type&reflectFlags) != 0) + runtime·throw("invalid interface value"); + // Reflect library has reinterpreted typ // as its own kind of type structure. // We know that the pointer to the original @@ -702,6 +747,9 @@ unsafe·New(Eface typ, void *ret) { Type *t; + if(((uintptr)typ.type&reflectFlags) != 0) + runtime·throw("invalid interface value"); + // Reflect library has reinterpreted typ // as its own kind of type structure. // We know that the pointer to the original @@ -721,6 +769,9 @@ unsafe·NewArray(Eface typ, uint32 n, void *ret) uint64 size; Type *t; + if(((uintptr)typ.type&reflectFlags) != 0) + runtime·throw("invalid interface value"); + // Reflect library has reinterpreted typ // as its own kind of type structure. // We know that the pointer to the original diff --git a/src/pkg/runtime/linux/386/signal.c b/src/pkg/runtime/linux/386/signal.c index 9b72ecbae..8916e10bd 100644 --- a/src/pkg/runtime/linux/386/signal.c +++ b/src/pkg/runtime/linux/386/signal.c @@ -175,3 +175,10 @@ runtime·resetcpuprofiler(int32 hz) } m->profilehz = hz; } + +void +os·sigpipe(void) +{ + sigaction(SIGPIPE, SIG_DFL, false); + runtime·raisesigpipe(); +} diff --git a/src/pkg/runtime/linux/386/sys.s b/src/pkg/runtime/linux/386/sys.s index c39ce253f..868a0d901 100644 --- a/src/pkg/runtime/linux/386/sys.s +++ b/src/pkg/runtime/linux/386/sys.s @@ -30,6 +30,14 @@ TEXT runtime·write(SB),7,$0 INT $0x80 RET +TEXT runtime·raisesigpipe(SB),7,$12 + MOVL $224, AX // syscall - gettid + INT $0x80 + MOVL AX, 0(SP) // arg 1 tid + MOVL $13, 4(SP) // arg 2 SIGPIPE + MOVL $238, AX // syscall - tkill + INT $0x80 + RET TEXT runtime·setitimer(SB),7,$0-24 MOVL $104, AX // syscall - setitimer diff --git a/src/pkg/runtime/linux/amd64/signal.c b/src/pkg/runtime/linux/amd64/signal.c index 1db9c95e5..ee90271ed 100644 --- a/src/pkg/runtime/linux/amd64/signal.c +++ b/src/pkg/runtime/linux/amd64/signal.c @@ -185,3 +185,10 @@ runtime·resetcpuprofiler(int32 hz) } m->profilehz = hz; } + +void +os·sigpipe(void) +{ + sigaction(SIGPIPE, SIG_DFL, false); + runtime·raisesigpipe(); +} diff --git a/src/pkg/runtime/linux/amd64/sys.s b/src/pkg/runtime/linux/amd64/sys.s index 11df1f894..eadd30005 100644 --- a/src/pkg/runtime/linux/amd64/sys.s +++ b/src/pkg/runtime/linux/amd64/sys.s @@ -36,6 +36,15 @@ TEXT runtime·write(SB),7,$0-24 SYSCALL RET +TEXT runtime·raisesigpipe(SB),7,$12 + MOVL $186, AX // syscall - gettid + SYSCALL + MOVL AX, DI // arg 1 tid + MOVL $13, SI // arg 2 SIGPIPE + MOVL $200, AX // syscall - tkill + SYSCALL + RET + TEXT runtime·setitimer(SB),7,$0-24 MOVL 8(SP), DI MOVQ 16(SP), SI diff --git a/src/pkg/runtime/linux/arm/signal.c b/src/pkg/runtime/linux/arm/signal.c index 05c6b0261..88a84d112 100644 --- a/src/pkg/runtime/linux/arm/signal.c +++ b/src/pkg/runtime/linux/arm/signal.c @@ -180,3 +180,10 @@ runtime·resetcpuprofiler(int32 hz) } m->profilehz = hz; } + +void +os·sigpipe(void) +{ + sigaction(SIGPIPE, SIG_DFL, false); + runtime·raisesigpipe(); +} diff --git a/src/pkg/runtime/linux/arm/sys.s b/src/pkg/runtime/linux/arm/sys.s index b9767a028..d866b0e22 100644 --- a/src/pkg/runtime/linux/arm/sys.s +++ b/src/pkg/runtime/linux/arm/sys.s @@ -22,11 +22,12 @@ #define SYS_rt_sigaction (SYS_BASE + 174) #define SYS_sigaltstack (SYS_BASE + 186) #define SYS_mmap2 (SYS_BASE + 192) -#define SYS_gettid (SYS_BASE + 224) #define SYS_futex (SYS_BASE + 240) #define SYS_exit_group (SYS_BASE + 248) #define SYS_munmap (SYS_BASE + 91) #define SYS_setitimer (SYS_BASE + 104) +#define SYS_gettid (SYS_BASE + 224) +#define SYS_tkill (SYS_BASE + 238) #define ARM_BASE (SYS_BASE + 0x0f0000) #define SYS_ARM_cacheflush (ARM_BASE + 2) @@ -55,6 +56,15 @@ TEXT runtime·exit1(SB),7,$-4 MOVW $1003, R1 MOVW R0, (R1) // fail hard +TEXT runtime·raisesigpipe(SB),7,$-4 + MOVW $SYS_gettid, R7 + SWI $0 + // arg 1 tid already in R0 from gettid + MOVW $13, R1 // arg 2 SIGPIPE + MOVW $SYS_tkill, R7 + SWI $0 + RET + TEXT runtime·mmap(SB),7,$0 MOVW 0(FP), R0 MOVW 4(FP), R1 diff --git a/src/pkg/runtime/linux/mem.c b/src/pkg/runtime/linux/mem.c index d2f6f8204..ce1a8aa70 100644 --- a/src/pkg/runtime/linux/mem.c +++ b/src/pkg/runtime/linux/mem.c @@ -48,6 +48,11 @@ runtime·SysReserve(void *v, uintptr n) return runtime·mmap(v, n, PROT_NONE, MAP_ANON|MAP_PRIVATE, -1, 0); } +enum +{ + ENOMEM = 12, +}; + void runtime·SysMap(void *v, uintptr n) { @@ -66,6 +71,8 @@ runtime·SysMap(void *v, uintptr n) } p = runtime·mmap(v, n, PROT_READ|PROT_WRITE|PROT_EXEC, MAP_ANON|MAP_FIXED|MAP_PRIVATE, -1, 0); + if(p == (void*)-ENOMEM) + runtime·throw("runtime: out of memory"); if(p != v) runtime·throw("runtime: cannot map pages in arena address space"); } diff --git a/src/pkg/runtime/linux/os.h b/src/pkg/runtime/linux/os.h index 6ae088977..0bb8d0339 100644 --- a/src/pkg/runtime/linux/os.h +++ b/src/pkg/runtime/linux/os.h @@ -15,3 +15,5 @@ void runtime·rt_sigaction(uintptr, struct Sigaction*, void*, uintptr); void runtime·sigaltstack(Sigaltstack*, Sigaltstack*); void runtime·sigpanic(void); void runtime·setitimer(int32, Itimerval*, Itimerval*); + +void runtime·raisesigpipe(void); diff --git a/src/pkg/runtime/malloc.goc b/src/pkg/runtime/malloc.goc index 41060682e..1f2d6da40 100644 --- a/src/pkg/runtime/malloc.goc +++ b/src/pkg/runtime/malloc.goc @@ -346,7 +346,7 @@ runtime·MHeap_SysAlloc(MHeap *h, uintptr n) return nil; if(p < h->arena_start || p+n - h->arena_start >= MaxArena32) { - runtime·printf("runtime: memory allocated by OS not in usable range"); + runtime·printf("runtime: memory allocated by OS not in usable range\n"); runtime·SysFree(p, n); return nil; } diff --git a/src/pkg/runtime/mcache.c b/src/pkg/runtime/mcache.c index 0f41a0ebc..e40621186 100644 --- a/src/pkg/runtime/mcache.c +++ b/src/pkg/runtime/mcache.c @@ -22,6 +22,8 @@ runtime·MCache_Alloc(MCache *c, int32 sizeclass, uintptr size, int32 zeroed) // Replenish using central lists. n = runtime·MCentral_AllocList(&runtime·mheap.central[sizeclass], runtime·class_to_transfercount[sizeclass], &first); + if(n == 0) + runtime·throw("out of memory"); l->list = first; l->nlist = n; c->size += n*size; diff --git a/src/pkg/runtime/mgc0.c b/src/pkg/runtime/mgc0.c index 14d485b71..ac6a1fa40 100644 --- a/src/pkg/runtime/mgc0.c +++ b/src/pkg/runtime/mgc0.c @@ -6,6 +6,7 @@ #include "runtime.h" #include "malloc.h" +#include "stack.h" enum { Debug = 0, @@ -92,6 +93,11 @@ scanblock(byte *b, int64 n) void **bw, **w, **ew; Workbuf *wbuf; + if((int64)(uintptr)n != n || n < 0) { + runtime·printf("scanblock %p %D\n", b, n); + runtime·throw("scanblock"); + } + // Memory arena parameters. arena_start = runtime·mheap.arena_start; @@ -323,20 +329,46 @@ getfull(Workbuf *b) static void scanstack(G *gp) { + int32 n; Stktop *stk; - byte *sp; + byte *sp, *guard; + + stk = (Stktop*)gp->stackbase; + guard = gp->stackguard; - if(gp == g) + if(gp == g) { + // Scanning our own stack: start at &gp. sp = (byte*)&gp; - else + } else { + // Scanning another goroutine's stack. + // The goroutine is usually asleep (the world is stopped). sp = gp->sched.sp; + + // The exception is that if the goroutine is about to enter or might + // have just exited a system call, it may be executing code such + // as schedlock and may have needed to start a new stack segment. + // Use the stack segment and stack pointer at the time of + // the system call instead, since that won't change underfoot. + if(gp->gcstack != nil) { + stk = (Stktop*)gp->gcstack; + sp = gp->gcsp; + guard = gp->gcguard; + } + } + if(Debug > 1) runtime·printf("scanstack %d %p\n", gp->goid, sp); - stk = (Stktop*)gp->stackbase; + n = 0; while(stk) { + if(sp < guard-StackGuard || (byte*)stk < sp) { + runtime·printf("scanstack inconsistent: g%d#%d sp=%p not in [%p,%p]\n", gp->goid, n, sp, guard-StackGuard, stk); + runtime·throw("scanstack"); + } scanblock(sp, (byte*)stk - sp); sp = stk->gobuf.sp; + guard = stk->stackguard; stk = (Stktop*)stk->stackbase; + n++; } } diff --git a/src/pkg/runtime/mheap.c b/src/pkg/runtime/mheap.c index 8061b7cf8..dde31ce34 100644 --- a/src/pkg/runtime/mheap.c +++ b/src/pkg/runtime/mheap.c @@ -180,9 +180,7 @@ MHeap_Grow(MHeap *h, uintptr npage) // Allocate a multiple of 64kB (16 pages). npage = (npage+15)&~15; ask = npage<<PageShift; - if(ask > h->arena_end - h->arena_used) - return false; - if(ask < HeapAllocChunk && HeapAllocChunk <= h->arena_end - h->arena_used) + if(ask < HeapAllocChunk) ask = HeapAllocChunk; v = runtime·MHeap_SysAlloc(h, ask); @@ -191,8 +189,10 @@ MHeap_Grow(MHeap *h, uintptr npage) ask = npage<<PageShift; v = runtime·MHeap_SysAlloc(h, ask); } - if(v == nil) + if(v == nil) { + runtime·printf("runtime: out of memory: cannot allocate %D-byte block (%D in use)\n", (uint64)ask, mstats.heap_sys); return false; + } } mstats.heap_sys += ask; diff --git a/src/pkg/runtime/mkversion.c b/src/pkg/runtime/mkversion.c index 56afa1892..0d96aa356 100644 --- a/src/pkg/runtime/mkversion.c +++ b/src/pkg/runtime/mkversion.c @@ -4,7 +4,7 @@ char *template = "// generated by mkversion.c; do not edit.\n" "package runtime\n" - "const defaultGoroot = \"%s\"\n" + "const defaultGoroot = `%s`\n" "const theVersion = \"%s\"\n"; void diff --git a/src/pkg/runtime/plan9/mem.c b/src/pkg/runtime/plan9/mem.c index b840de984..9dfdf2cc3 100644 --- a/src/pkg/runtime/plan9/mem.c +++ b/src/pkg/runtime/plan9/mem.c @@ -4,6 +4,7 @@ #include "runtime.h" #include "malloc.h" +#include "os.h" extern byte end[]; static byte *bloc = { end }; @@ -52,5 +53,6 @@ runtime·SysMap(void *v, uintptr nbytes) void* runtime·SysReserve(void *v, uintptr nbytes) { + USED(v); return runtime·SysAlloc(nbytes); } diff --git a/src/pkg/runtime/plan9/thread.c b/src/pkg/runtime/plan9/thread.c index fa96552a9..7c6ca45a3 100644 --- a/src/pkg/runtime/plan9/thread.c +++ b/src/pkg/runtime/plan9/thread.c @@ -138,3 +138,8 @@ runtime·notewakeup(Note *n) runtime·usemrelease(&n->sema); } +void +os·sigpipe(void) +{ + runtime·throw("too many writes on closed pipe"); +} diff --git a/src/pkg/runtime/proc.c b/src/pkg/runtime/proc.c index e212c7820..52784854f 100644 --- a/src/pkg/runtime/proc.c +++ b/src/pkg/runtime/proc.c @@ -590,6 +590,9 @@ schedule(G *gp) // re-queues g and runs everyone else who is waiting // before running g again. If g->status is Gmoribund, // kills off g. +// Cannot split stack because it is called from exitsyscall. +// See comment below. +#pragma textflag 7 void runtime·gosched(void) { @@ -604,19 +607,17 @@ runtime·gosched(void) // Record that it's not using the cpu anymore. // This is called only from the go syscall library and cgocall, // not from the low-level system calls used by the runtime. +// // Entersyscall cannot split the stack: the runtime·gosave must -// make g->sched refer to the caller's stack pointer. +// make g->sched refer to the caller's stack segment, because +// entersyscall is going to return immediately after. // It's okay to call matchmg and notewakeup even after // decrementing mcpu, because we haven't released the -// sched lock yet. +// sched lock yet, so the garbage collector cannot be running. #pragma textflag 7 void runtime·entersyscall(void) { - // Leave SP around for gc and traceback. - // Do before notewakeup so that gc - // never sees Gsyscall with wrong stack. - runtime·gosave(&g->sched); if(runtime·sched.predawn) return; schedlock(); @@ -625,10 +626,23 @@ runtime·entersyscall(void) runtime·sched.msyscall++; if(runtime·sched.gwait != 0) matchmg(); + if(runtime·sched.waitstop && runtime·sched.mcpu <= runtime·sched.mcpumax) { runtime·sched.waitstop = 0; runtime·notewakeup(&runtime·sched.stopped); } + + // Leave SP around for gc and traceback. + // Do before schedunlock so that gc + // never sees Gsyscall with wrong stack. + runtime·gosave(&g->sched); + g->gcsp = g->sched.sp; + g->gcstack = g->stackbase; + g->gcguard = g->stackguard; + if(g->gcsp < g->gcguard-StackGuard || g->gcstack < g->gcsp) { + runtime·printf("entersyscall inconsistent %p [%p,%p]\n", g->gcsp, g->gcguard-StackGuard, g->gcstack); + runtime·throw("entersyscall"); + } schedunlock(); } @@ -647,7 +661,11 @@ runtime·exitsyscall(void) runtime·sched.mcpu++; // Fast path - if there's room for this m, we're done. if(m->profilehz == runtime·sched.profilehz && runtime·sched.mcpu <= runtime·sched.mcpumax) { + // There's a cpu for us, so we can run. g->status = Grunning; + // Garbage collector isn't running (since we are), + // so okay to clear gcstack. + g->gcstack = nil; schedunlock(); return; } @@ -663,6 +681,14 @@ runtime·exitsyscall(void) // When the scheduler takes g away from m, // it will undo the runtime·sched.mcpu++ above. runtime·gosched(); + + // Gosched returned, so we're allowed to run now. + // Delete the gcstack information that we left for + // the garbage collector during the system call. + // Must wait until now because until gosched returns + // we don't know for sure that the garbage collector + // is not running. + g->gcstack = nil; } void @@ -1196,6 +1222,12 @@ runtime·gomaxprocsfunc(int32 n) if (n <= 0) n = ret; runtime·gomaxprocs = n; + if (runtime·gcwaiting != 0) { + if (runtime·sched.mcpumax != 1) + runtime·throw("invalid runtime·sched.mcpumax during gc"); + schedunlock(); + return ret; + } runtime·sched.mcpumax = n; // handle fewer procs? if(runtime·sched.mcpu > runtime·sched.mcpumax) { diff --git a/src/pkg/runtime/proc_test.go b/src/pkg/runtime/proc_test.go new file mode 100644 index 000000000..a15b2d80a --- /dev/null +++ b/src/pkg/runtime/proc_test.go @@ -0,0 +1,43 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package runtime_test + +import ( + "runtime" + "testing" +) + +var stop = make(chan bool, 1) + +func perpetuumMobile() { + select { + case <-stop: + default: + go perpetuumMobile() + } +} + +func TestStopTheWorldDeadlock(t *testing.T) { + if testing.Short() { + t.Logf("skipping during short test") + return + } + runtime.GOMAXPROCS(3) + compl := make(chan int, 1) + go func() { + for i := 0; i != 1000; i += 1 { + runtime.GC() + } + compl <- 0 + }() + go func() { + for i := 0; i != 1000; i += 1 { + runtime.GOMAXPROCS(3) + } + }() + go perpetuumMobile() + <-compl + stop <- true +} diff --git a/src/pkg/runtime/reflect.goc b/src/pkg/runtime/reflect.goc deleted file mode 100644 index 9bdc48afb..000000000 --- a/src/pkg/runtime/reflect.goc +++ /dev/null @@ -1,114 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package reflect -#include "runtime.h" -#include "type.h" - -static Type* -gettype(void *typ) -{ - // typ is a *runtime.Type (or *runtime.MapType, etc), but the Type - // defined in type.h includes an interface value header - // in front of the raw structure. the -2 below backs up - // to the interface value header. - return (Type*)((void**)typ - 2); -} - -/* - * Go wrappers around the C functions near the bottom of hashmap.c - * There's no recursion here even though it looks like there is: - * the names after func are in the reflect package name space - * but the names in the C bodies are in the standard C name space. - */ - -func mapaccess(map *byte, key *byte, val *byte) (pres bool) { - runtime·mapaccess((Hmap*)map, key, val, &pres); -} - -func mapassign(map *byte, key *byte, val *byte) { - runtime·mapassign((Hmap*)map, key, val); -} - -func maplen(map *byte) (len int32) { - // length is first word of map - len = *(uint32*)map; -} - -func mapiterinit(map *byte) (it *byte) { - it = (byte*)runtime·newmapiterinit((Hmap*)map); -} - -func mapiternext(it *byte) { - runtime·mapiternext((struct hash_iter*)it); -} - -func mapiterkey(it *byte, key *byte) (ok bool) { - ok = runtime·mapiterkey((struct hash_iter*)it, key); -} - -func makemap(typ *byte) (map *byte) { - MapType *t; - - t = (MapType*)gettype(typ); - map = (byte*)runtime·makemap_c(t->key, t->elem, 0); -} - -/* - * Go wrappers around the C functions in chan.c - */ - -func makechan(typ *byte, size uint32) (ch *byte) { - ChanType *t; - - // typ is a *runtime.ChanType, but the ChanType - // defined in type.h includes an interface value header - // in front of the raw ChanType. the -2 below backs up - // to the interface value header. - t = (ChanType*)gettype(typ); - ch = (byte*)runtime·makechan_c(t->elem, size); -} - -func chansend(ch *byte, val *byte, selected *bool) { - runtime·chansend((Hchan*)ch, val, selected); -} - -func chanrecv(ch *byte, val *byte, selected *bool, received *bool) { - runtime·chanrecv((Hchan*)ch, val, selected, received); -} - -func chanclose(ch *byte) { - runtime·chanclose((Hchan*)ch); -} - -func chanlen(ch *byte) (r int32) { - r = runtime·chanlen((Hchan*)ch); -} - -func chancap(ch *byte) (r int32) { - r = runtime·chancap((Hchan*)ch); -} - - -/* - * Go wrappers around the functions in iface.c - */ - -func setiface(typ *byte, x *byte, ret *byte) { - InterfaceType *t; - - t = (InterfaceType*)gettype(typ); - if(t->mhdr.len == 0) { - // already an empty interface - *(Eface*)ret = *(Eface*)x; - return; - } - if(((Eface*)x)->type == nil) { - // can assign nil to any interface - ((Iface*)ret)->tab = nil; - ((Iface*)ret)->data = nil; - return; - } - runtime·ifaceE2I((InterfaceType*)gettype(typ), *(Eface*)x, (Iface*)ret); -} diff --git a/src/pkg/runtime/runtime-gdb.py b/src/pkg/runtime/runtime-gdb.py index 08772a431..3f767fbdd 100644 --- a/src/pkg/runtime/runtime-gdb.py +++ b/src/pkg/runtime/runtime-gdb.py @@ -122,10 +122,13 @@ class ChanTypePrinter: return str(self.val.type) def children(self): - ptr = self.val['recvdataq'] - for idx in range(self.val["qcount"]): - yield ('[%d]' % idx, ptr['elem']) - ptr = ptr['link'] + # see chan.c chanbuf() + et = [x.type for x in self.val['free'].type.target().fields() if x.name == 'elem'][0] + ptr = (self.val.address + 1).cast(et.pointer()) + for i in range(self.val["qcount"]): + j = (self.val["recvx"] + i) % self.val["dataqsiz"] + yield ('[%d]' % i, (ptr + j).dereference()) + # # Register all the *Printer classes above. diff --git a/src/pkg/runtime/runtime.h b/src/pkg/runtime/runtime.h index 6cf2685fd..f9b404e15 100644 --- a/src/pkg/runtime/runtime.h +++ b/src/pkg/runtime/runtime.h @@ -183,6 +183,9 @@ struct G Defer* defer; Panic* panic; Gobuf sched; + byte* gcstack; // if status==Gsyscall, gcstack = stackbase to use during gc + byte* gcsp; // if status==Gsyscall, gcsp = sched.sp to use during gc + byte* gcguard; // if status==Gsyscall, gcguard = stackguard to use during gc byte* stack0; byte* entry; // initial function G* alllink; // on allg @@ -241,6 +244,7 @@ struct M void* sehframe; #endif }; + struct Stktop { // The offsets of these fields are known to (hard-coded in) libmach. @@ -580,7 +584,6 @@ int32 runtime·gomaxprocsfunc(int32 n); void runtime·mapassign(Hmap*, byte*, byte*); void runtime·mapaccess(Hmap*, byte*, byte*, bool*); -struct hash_iter* runtime·newmapiterinit(Hmap*); void runtime·mapiternext(struct hash_iter*); bool runtime·mapiterkey(struct hash_iter*, void*); void runtime·mapiterkeyvalue(struct hash_iter*, void*, void*); @@ -589,7 +592,6 @@ Hmap* runtime·makemap_c(Type*, Type*, int64); Hchan* runtime·makechan_c(Type*, int64); void runtime·chansend(Hchan*, void*, bool*); void runtime·chanrecv(Hchan*, void*, bool*, bool*); -void runtime·chanclose(Hchan*); int32 runtime·chanlen(Hchan*); int32 runtime·chancap(Hchan*); diff --git a/src/pkg/runtime/symtab.c b/src/pkg/runtime/symtab.c index 6f0eea0e7..da4579734 100644 --- a/src/pkg/runtime/symtab.c +++ b/src/pkg/runtime/symtab.c @@ -291,7 +291,9 @@ splitpcln(void) if(f < ef && pc >= (f+1)->entry) { f->pcln.len = p - f->pcln.array; f->pcln.cap = f->pcln.len; - f++; + do + f++; + while(f < ef && pc >= (f+1)->entry); f->pcln.array = p; // pc0 and ln0 are the starting values for // the loop over f->pcln, so pc must be diff --git a/src/pkg/runtime/type.go b/src/pkg/runtime/type.go index 71ad4e7a5..30f3ec642 100644 --- a/src/pkg/runtime/type.go +++ b/src/pkg/runtime/type.go @@ -117,8 +117,9 @@ type UnsafePointerType commonType // ArrayType represents a fixed array type. type ArrayType struct { commonType - elem *Type // array element type - len uintptr + elem *Type // array element type + slice *Type // slice type + len uintptr } // SliceType represents a slice type. diff --git a/src/pkg/runtime/windows/thread.c b/src/pkg/runtime/windows/thread.c index aedd24200..2ce92dcfb 100644 --- a/src/pkg/runtime/windows/thread.c +++ b/src/pkg/runtime/windows/thread.c @@ -378,3 +378,9 @@ runtime·compilecallback(Eface fn, bool cleanstack) return ret; } + +void +os·sigpipe(void) +{ + runtime·throw("too many writes on closed pipe"); +} diff --git a/src/pkg/scanner/scanner.go b/src/pkg/scanner/scanner.go index ec2266477..e79d392f7 100644 --- a/src/pkg/scanner/scanner.go +++ b/src/pkg/scanner/scanner.go @@ -2,10 +2,11 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// A scanner and tokenizer for UTF-8-encoded text. Takes an io.Reader -// providing the source, which then can be tokenized through repeated calls -// to the Scan function. For compatibility with existing tools, the NUL -// character is not allowed (implementation restriction). +// Package scanner provides a scanner and tokenizer for UTF-8-encoded text. +// It takes an io.Reader providing the source, which then can be tokenized +// through repeated calls to the Scan function. For compatibility with +// existing tools, the NUL character is not allowed (implementation +// restriction). // // By default, a Scanner skips white space and Go comments and recognizes all // literals as defined by the Go language specification. It may be diff --git a/src/pkg/sort/sort.go b/src/pkg/sort/sort.go index c7945d21b..30b1819af 100644 --- a/src/pkg/sort/sort.go +++ b/src/pkg/sort/sort.go @@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The sort package provides primitives for sorting arrays -// and user-defined collections. +// Package sort provides primitives for sorting arrays and user-defined +// collections. package sort // A type, typically a collection, that satisfies sort.Interface can be diff --git a/src/pkg/strconv/atof.go b/src/pkg/strconv/atof.go index 72f162c51..a91e8bfa4 100644 --- a/src/pkg/strconv/atof.go +++ b/src/pkg/strconv/atof.go @@ -2,16 +2,16 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +// Package strconv implements conversions to and from string representations +// of basic data types. +package strconv + // decimal to binary floating point conversion. // Algorithm: // 1) Store input in multiprecision decimal. // 2) Multiply/divide decimal by powers of two until in range [0.5, 1) // 3) Multiply by 2^precision and round to get mantissa. -// The strconv package implements conversions to and from -// string representations of basic data types. -package strconv - import ( "math" "os" diff --git a/src/pkg/strings/strings.go b/src/pkg/strings/strings.go index 93c7c4647..bfd057180 100644 --- a/src/pkg/strings/strings.go +++ b/src/pkg/strings/strings.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// A package of simple functions to manipulate strings. +// Package strings implements simple functions to manipulate strings. package strings import ( diff --git a/src/pkg/sync/mutex.go b/src/pkg/sync/mutex.go index da565d38d..13f03cad3 100644 --- a/src/pkg/sync/mutex.go +++ b/src/pkg/sync/mutex.go @@ -2,11 +2,10 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The sync package provides basic synchronization primitives -// such as mutual exclusion locks. Other than the Once and -// WaitGroup types, most are intended for use by low-level -// library routines. Higher-level synchronization is better -// done via channels and communication. +// Package sync provides basic synchronization primitives such as mutual +// exclusion locks. Other than the Once and WaitGroup types, most are intended +// for use by low-level library routines. Higher-level synchronization is +// better done via channels and communication. package sync import ( diff --git a/src/pkg/syscall/exec_windows.go b/src/pkg/syscall/exec_windows.go index aeee191dd..85b1c2eda 100644 --- a/src/pkg/syscall/exec_windows.go +++ b/src/pkg/syscall/exec_windows.go @@ -8,6 +8,7 @@ package syscall import ( "sync" + "unsafe" "utf16" ) @@ -217,9 +218,10 @@ func joinExeDirAndFName(dir, p string) (name string, err int) { } type ProcAttr struct { - Dir string - Env []string - Files []int + Dir string + Env []string + Files []int + HideWindow bool } var zeroAttributes ProcAttr @@ -279,8 +281,12 @@ func StartProcess(argv0 string, argv []string, attr *ProcAttr) (pid, handle int, } } si := new(StartupInfo) - GetStartupInfo(si) + si.Cb = uint32(unsafe.Sizeof(*si)) si.Flags = STARTF_USESTDHANDLES + if attr.HideWindow { + si.Flags |= STARTF_USESHOWWINDOW + si.ShowWindow = SW_HIDE + } si.StdInput = fd[0] si.StdOutput = fd[1] si.StdErr = fd[2] diff --git a/src/pkg/syscall/mkerrors.sh b/src/pkg/syscall/mkerrors.sh index 68a16842a..0bfd9af1d 100755 --- a/src/pkg/syscall/mkerrors.sh +++ b/src/pkg/syscall/mkerrors.sh @@ -47,6 +47,7 @@ includes_Darwin=' #include <sys/sysctl.h> #include <sys/mman.h> #include <sys/wait.h> +#include <net/bpf.h> #include <net/if.h> #include <net/route.h> #include <netinet/in.h> @@ -134,6 +135,7 @@ done $2 ~ /^SIOC/ || $2 ~ /^(IFF|NET_RT|RTM|RTF|RTV|RTA|RTAX)_/ || $2 ~ /^BIOC/ || + $2 !~ /^(BPF_TIMEVAL)$/ && $2 ~ /^(BPF|DLT)_/ || $2 !~ "WMESGLEN" && $2 ~ /^W[A-Z0-9]+$/ {printf("\t$%s = %s,\n", $2, $2)} diff --git a/src/pkg/syscall/syscall.go b/src/pkg/syscall/syscall.go index 2a9ffd4af..157abaa8b 100644 --- a/src/pkg/syscall/syscall.go +++ b/src/pkg/syscall/syscall.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package contains an interface to the low-level operating system +// Package syscall contains an interface to the low-level operating system // primitives. The details vary depending on the underlying system. // Its primary use is inside other packages that provide a more portable // interface to the system, such as "os", "time" and "net". Use those diff --git a/src/pkg/syscall/syscall_linux.go b/src/pkg/syscall/syscall_linux.go index 2b221bd60..4a3797c20 100644 --- a/src/pkg/syscall/syscall_linux.go +++ b/src/pkg/syscall/syscall_linux.go @@ -814,6 +814,13 @@ func Munmap(b []byte) (errno int) { return mapper.Munmap(b) } +//sys Madvise(b []byte, advice int) (errno int) +//sys Mprotect(b []byte, prot int) (errno int) +//sys Mlock(b []byte) (errno int) +//sys Munlock(b []byte) (errno int) +//sys Mlockall(flags int) (errno int) +//sys Munlockall() (errno int) + /* * Unimplemented */ @@ -868,12 +875,9 @@ func Munmap(b []byte) (errno int) { // LookupDcookie // Lremovexattr // Lsetxattr -// Madvise // Mbind // MigratePages // Mincore -// Mlock -// Mmap // ModifyLdt // Mount // MovePages @@ -890,9 +894,6 @@ func Munmap(b []byte) (errno int) { // Msgrcv // Msgsnd // Msync -// Munlock -// Munlockall -// Munmap // Newfstatat // Nfsservctl // Personality diff --git a/src/pkg/syscall/syscall_linux_arm.go b/src/pkg/syscall/syscall_linux_arm.go index 6472c4db5..458745885 100644 --- a/src/pkg/syscall/syscall_linux_arm.go +++ b/src/pkg/syscall/syscall_linux_arm.go @@ -24,7 +24,6 @@ func NsecToTimeval(nsec int64) (tv Timeval) { } // Pread and Pwrite are special: they insert padding before the int64. -// (Ftruncate and truncate are not; go figure.) func Pread(fd int, p []byte, offset int64) (n int, errno int) { var _p0 unsafe.Pointer @@ -48,6 +47,20 @@ func Pwrite(fd int, p []byte, offset int64) (n int, errno int) { return } +func Ftruncate(fd int, length int64) (errno int) { + // ARM EABI requires 64-bit arguments should be put in a pair + // of registers from an even register number. + _, _, e1 := Syscall6(SYS_FTRUNCATE64, uintptr(fd), 0, uintptr(length), uintptr(length>>32), 0, 0) + errno = int(e1) + return +} + +func Truncate(path string, length int64) (errno int) { + _, _, e1 := Syscall6(SYS_TRUNCATE64, uintptr(unsafe.Pointer(StringBytePtr(path))), 0, uintptr(length), uintptr(length>>32), 0, 0) + errno = int(e1) + return +} + // Seek is defined in assembly. func Seek(fd int, offset int64, whence int) (newoffset int64, errno int) @@ -72,7 +85,6 @@ func Seek(fd int, offset int64, whence int) (newoffset int64, errno int) //sys Fchown(fd int, uid int, gid int) (errno int) //sys Fstat(fd int, stat *Stat_t) (errno int) = SYS_FSTAT64 //sys Fstatfs(fd int, buf *Statfs_t) (errno int) = SYS_FSTATFS64 -//sys Ftruncate(fd int, length int64) (errno int) = SYS_FTRUNCATE64 //sysnb Getegid() (egid int) //sysnb Geteuid() (euid int) //sysnb Getgid() (gid int) @@ -92,7 +104,6 @@ func Seek(fd int, offset int64, whence int) (newoffset int64, errno int) //sys Splice(rfd int, roff *int64, wfd int, woff *int64, len int, flags int) (n int, errno int) //sys Stat(path string, stat *Stat_t) (errno int) = SYS_STAT64 //sys Statfs(path string, buf *Statfs_t) (errno int) = SYS_STATFS64 -//sys Truncate(path string, length int64) (errno int) = SYS_TRUNCATE64 // Vsyscalls on amd64. //sysnb Gettimeofday(tv *Timeval) (errno int) diff --git a/src/pkg/syscall/syscall_windows.go b/src/pkg/syscall/syscall_windows.go index 4ac2154c8..1fbb3ccbf 100644 --- a/src/pkg/syscall/syscall_windows.go +++ b/src/pkg/syscall/syscall_windows.go @@ -220,9 +220,12 @@ func Open(path string, mode int, perm uint32) (fd int, errno int) { var createmode uint32 switch { case mode&O_CREAT != 0: - if mode&O_EXCL != 0 { + switch { + case mode&O_EXCL != 0: createmode = CREATE_NEW - } else { + case mode&O_APPEND != 0: + createmode = OPEN_ALWAYS + default: createmode = CREATE_ALWAYS } case mode&O_TRUNC != 0: @@ -247,27 +250,6 @@ func Read(fd int, p []byte) (n int, errno int) { return int(done), 0 } -// TODO(brainman): ReadFile/WriteFile change file offset, therefore -// i use Seek here to preserve semantics of unix pread/pwrite, -// not sure if I should do that - -func Pread(fd int, p []byte, offset int64) (n int, errno int) { - curoffset, e := Seek(fd, 0, 1) - if e != 0 { - return 0, e - } - defer Seek(fd, curoffset, 0) - var o Overlapped - o.OffsetHigh = uint32(offset >> 32) - o.Offset = uint32(offset) - var done uint32 - e = ReadFile(int32(fd), p, &done, &o) - if e != 0 { - return 0, e - } - return int(done), 0 -} - func Write(fd int, p []byte) (n int, errno int) { var done uint32 e := WriteFile(int32(fd), p, &done, nil) @@ -277,23 +259,6 @@ func Write(fd int, p []byte) (n int, errno int) { return int(done), 0 } -func Pwrite(fd int, p []byte, offset int64) (n int, errno int) { - curoffset, e := Seek(fd, 0, 1) - if e != 0 { - return 0, e - } - defer Seek(fd, curoffset, 0) - var o Overlapped - o.OffsetHigh = uint32(offset >> 32) - o.Offset = uint32(offset) - var done uint32 - e = WriteFile(int32(fd), p, &done, &o) - if e != 0 { - return 0, e - } - return int(done), 0 -} - func Seek(fd int, offset int64, whence int) (newoffset int64, errno int) { var w uint32 switch whence { diff --git a/src/pkg/syscall/types_darwin.c b/src/pkg/syscall/types_darwin.c index 4096bcfd9..666923a68 100644 --- a/src/pkg/syscall/types_darwin.c +++ b/src/pkg/syscall/types_darwin.c @@ -29,6 +29,7 @@ Input to godefs. See also mkerrors.sh and mkall.sh #include <sys/types.h> #include <sys/un.h> #include <sys/wait.h> +#include <net/bpf.h> #include <net/if.h> #include <net/if_dl.h> #include <net/if_var.h> @@ -59,6 +60,7 @@ typedef long long $_C_long_long; typedef struct timespec $Timespec; typedef struct timeval $Timeval; +typedef struct timeval32 $Timeval32; // Processes @@ -157,3 +159,19 @@ typedef struct if_data $IfData; typedef struct ifa_msghdr $IfaMsghdr; typedef struct rt_msghdr $RtMsghdr; typedef struct rt_metrics $RtMetrics; + +// Berkeley packet filter + +enum { + $SizeofBpfVersion = sizeof(struct bpf_version), + $SizeofBpfStat = sizeof(struct bpf_stat), + $SizeofBpfProgram = sizeof(struct bpf_program), + $SizeofBpfInsn = sizeof(struct bpf_insn), + $SizeofBpfHdr = sizeof(struct bpf_hdr), +}; + +typedef struct bpf_version $BpfVersion; +typedef struct bpf_stat $BpfStat; +typedef struct bpf_program $BpfProgram; +typedef struct bpf_insn $BpfInsn; +typedef struct bpf_hdr $BpfHdr; diff --git a/src/pkg/syscall/zerrors_darwin_386.go b/src/pkg/syscall/zerrors_darwin_386.go index 48f563f44..7bc1280d6 100644 --- a/src/pkg/syscall/zerrors_darwin_386.go +++ b/src/pkg/syscall/zerrors_darwin_386.go @@ -45,8 +45,109 @@ const ( AF_SYSTEM = 0x20 AF_UNIX = 0x1 AF_UNSPEC = 0 + BIOCFLUSH = 0x20004268 + BIOCGBLEN = 0x40044266 + BIOCGDLT = 0x4004426a + BIOCGDLTLIST = 0xc00c4279 + BIOCGETIF = 0x4020426b + BIOCGHDRCMPLT = 0x40044274 + BIOCGRSIG = 0x40044272 + BIOCGRTIMEOUT = 0x4008426e + BIOCGSEESENT = 0x40044276 + BIOCGSTATS = 0x4008426f + BIOCIMMEDIATE = 0x80044270 + BIOCPROMISC = 0x20004269 + BIOCSBLEN = 0xc0044266 + BIOCSDLT = 0x80044278 + BIOCSETF = 0x80084267 + BIOCSETIF = 0x8020426c + BIOCSHDRCMPLT = 0x80044275 + BIOCSRSIG = 0x80044273 + BIOCSRTIMEOUT = 0x8008426d + BIOCSSEESENT = 0x80044277 + BIOCVERSION = 0x40044271 + BPF_A = 0x10 + BPF_ABS = 0x20 + BPF_ADD = 0 + BPF_ALIGNMENT = 0x4 + BPF_ALU = 0x4 + BPF_AND = 0x50 + BPF_B = 0x10 + BPF_DIV = 0x30 + BPF_H = 0x8 + BPF_IMM = 0 + BPF_IND = 0x40 + BPF_JA = 0 + BPF_JEQ = 0x10 + BPF_JGE = 0x30 + BPF_JGT = 0x20 + BPF_JMP = 0x5 + BPF_JSET = 0x40 + BPF_K = 0 + BPF_LD = 0 + BPF_LDX = 0x1 + BPF_LEN = 0x80 + BPF_LSH = 0x60 + BPF_MAJOR_VERSION = 0x1 + BPF_MAXBUFSIZE = 0x80000 + BPF_MAXINSNS = 0x200 + BPF_MEM = 0x60 + BPF_MEMWORDS = 0x10 + BPF_MINBUFSIZE = 0x20 + BPF_MINOR_VERSION = 0x1 + BPF_MISC = 0x7 + BPF_MSH = 0xa0 + BPF_MUL = 0x20 + BPF_NEG = 0x80 + BPF_OR = 0x40 + BPF_RELEASE = 0x30bb6 + BPF_RET = 0x6 + BPF_RSH = 0x70 + BPF_ST = 0x2 + BPF_STX = 0x3 + BPF_SUB = 0x10 + BPF_TAX = 0 + BPF_TXA = 0x80 + BPF_W = 0 + BPF_X = 0x8 CTL_MAXNAME = 0xc CTL_NET = 0x4 + DLT_APPLE_IP_OVER_IEEE1394 = 0x8a + DLT_ARCNET = 0x7 + DLT_ATM_CLIP = 0x13 + DLT_ATM_RFC1483 = 0xb + DLT_AX25 = 0x3 + DLT_CHAOS = 0x5 + DLT_CHDLC = 0x68 + DLT_C_HDLC = 0x68 + DLT_EN10MB = 0x1 + DLT_EN3MB = 0x2 + DLT_FDDI = 0xa + DLT_IEEE802 = 0x6 + DLT_IEEE802_11 = 0x69 + DLT_IEEE802_11_RADIO = 0x7f + DLT_IEEE802_11_RADIO_AVS = 0xa3 + DLT_LINUX_SLL = 0x71 + DLT_LOOP = 0x6c + DLT_NULL = 0 + DLT_PFLOG = 0x75 + DLT_PFSYNC = 0x12 + DLT_PPP = 0x9 + DLT_PPP_BSDOS = 0x10 + DLT_PPP_SERIAL = 0x32 + DLT_PRONET = 0x4 + DLT_RAW = 0xc + DLT_SLIP = 0x8 + DLT_SLIP_BSDOS = 0xf + DT_BLK = 0x6 + DT_CHR = 0x2 + DT_DIR = 0x4 + DT_FIFO = 0x1 + DT_LNK = 0xa + DT_REG = 0x8 + DT_SOCK = 0xc + DT_UNKNOWN = 0 + DT_WHT = 0xe E2BIG = 0x7 EACCES = 0xd EADDRINUSE = 0x30 @@ -196,6 +297,7 @@ const ( F_GETLK = 0x7 F_GETOWN = 0x5 F_GETPATH = 0x32 + F_GETPROTECTIONCLASS = 0x3e F_GLOBAL_NOCACHE = 0x37 F_LOG2PHYS = 0x31 F_MARKDEPENDENCY = 0x3c @@ -212,6 +314,7 @@ const ( F_SETLK = 0x8 F_SETLKW = 0x9 F_SETOWN = 0x6 + F_SETPROTECTIONCLASS = 0x3f F_SETSIZE = 0x2b F_THAW_FS = 0x36 F_UNLCK = 0x2 @@ -459,6 +562,16 @@ const ( IP_TOS = 0x3 IP_TRAFFIC_MGT_BACKGROUND = 0x41 IP_TTL = 0x4 + MADV_CAN_REUSE = 0x9 + MADV_DONTNEED = 0x4 + MADV_FREE = 0x5 + MADV_FREE_REUSABLE = 0x7 + MADV_FREE_REUSE = 0x8 + MADV_NORMAL = 0 + MADV_RANDOM = 0x1 + MADV_SEQUENTIAL = 0x2 + MADV_WILLNEED = 0x3 + MADV_ZERO_WIRED_PAGES = 0x6 MAP_ANON = 0x1000 MAP_COPY = 0x2 MAP_FILE = 0 @@ -556,6 +669,7 @@ const ( RTF_DYNAMIC = 0x10 RTF_GATEWAY = 0x2 RTF_HOST = 0x4 + RTF_IFREF = 0x4000000 RTF_IFSCOPE = 0x1000000 RTF_LLINFO = 0x400 RTF_LOCAL = 0x200000 @@ -649,6 +763,7 @@ const ( SIOCDIFADDR = 0x80206919 SIOCDIFPHYADDR = 0x80206941 SIOCDLIFADDR = 0x8118691f + SIOCGDRVSPEC = 0xc01c697b SIOCGETSGCNT = 0xc014721c SIOCGETVIFCNT = 0xc014721b SIOCGETVLAN = 0xc020697f @@ -680,8 +795,10 @@ const ( SIOCGLOWAT = 0x40047303 SIOCGPGRP = 0x40047309 SIOCIFCREATE = 0xc0206978 + SIOCIFCREATE2 = 0xc020697a SIOCIFDESTROY = 0x80206979 SIOCRSLVMULTI = 0xc008693b + SIOCSDRVSPEC = 0x801c697b SIOCSETVLAN = 0x8020697e SIOCSHIWAT = 0x80047300 SIOCSIFADDR = 0x8020690c diff --git a/src/pkg/syscall/zerrors_darwin_amd64.go b/src/pkg/syscall/zerrors_darwin_amd64.go index 840ea13ce..d76f09220 100644 --- a/src/pkg/syscall/zerrors_darwin_amd64.go +++ b/src/pkg/syscall/zerrors_darwin_amd64.go @@ -45,8 +45,109 @@ const ( AF_SYSTEM = 0x20 AF_UNIX = 0x1 AF_UNSPEC = 0 + BIOCFLUSH = 0x20004268 + BIOCGBLEN = 0x40044266 + BIOCGDLT = 0x4004426a + BIOCGDLTLIST = 0xc00c4279 + BIOCGETIF = 0x4020426b + BIOCGHDRCMPLT = 0x40044274 + BIOCGRSIG = 0x40044272 + BIOCGRTIMEOUT = 0x4008426e + BIOCGSEESENT = 0x40044276 + BIOCGSTATS = 0x4008426f + BIOCIMMEDIATE = 0x80044270 + BIOCPROMISC = 0x20004269 + BIOCSBLEN = 0xc0044266 + BIOCSDLT = 0x80044278 + BIOCSETF = 0x80104267 + BIOCSETIF = 0x8020426c + BIOCSHDRCMPLT = 0x80044275 + BIOCSRSIG = 0x80044273 + BIOCSRTIMEOUT = 0x8008426d + BIOCSSEESENT = 0x80044277 + BIOCVERSION = 0x40044271 + BPF_A = 0x10 + BPF_ABS = 0x20 + BPF_ADD = 0 + BPF_ALIGNMENT = 0x4 + BPF_ALU = 0x4 + BPF_AND = 0x50 + BPF_B = 0x10 + BPF_DIV = 0x30 + BPF_H = 0x8 + BPF_IMM = 0 + BPF_IND = 0x40 + BPF_JA = 0 + BPF_JEQ = 0x10 + BPF_JGE = 0x30 + BPF_JGT = 0x20 + BPF_JMP = 0x5 + BPF_JSET = 0x40 + BPF_K = 0 + BPF_LD = 0 + BPF_LDX = 0x1 + BPF_LEN = 0x80 + BPF_LSH = 0x60 + BPF_MAJOR_VERSION = 0x1 + BPF_MAXBUFSIZE = 0x80000 + BPF_MAXINSNS = 0x200 + BPF_MEM = 0x60 + BPF_MEMWORDS = 0x10 + BPF_MINBUFSIZE = 0x20 + BPF_MINOR_VERSION = 0x1 + BPF_MISC = 0x7 + BPF_MSH = 0xa0 + BPF_MUL = 0x20 + BPF_NEG = 0x80 + BPF_OR = 0x40 + BPF_RELEASE = 0x30bb6 + BPF_RET = 0x6 + BPF_RSH = 0x70 + BPF_ST = 0x2 + BPF_STX = 0x3 + BPF_SUB = 0x10 + BPF_TAX = 0 + BPF_TXA = 0x80 + BPF_W = 0 + BPF_X = 0x8 CTL_MAXNAME = 0xc CTL_NET = 0x4 + DLT_APPLE_IP_OVER_IEEE1394 = 0x8a + DLT_ARCNET = 0x7 + DLT_ATM_CLIP = 0x13 + DLT_ATM_RFC1483 = 0xb + DLT_AX25 = 0x3 + DLT_CHAOS = 0x5 + DLT_CHDLC = 0x68 + DLT_C_HDLC = 0x68 + DLT_EN10MB = 0x1 + DLT_EN3MB = 0x2 + DLT_FDDI = 0xa + DLT_IEEE802 = 0x6 + DLT_IEEE802_11 = 0x69 + DLT_IEEE802_11_RADIO = 0x7f + DLT_IEEE802_11_RADIO_AVS = 0xa3 + DLT_LINUX_SLL = 0x71 + DLT_LOOP = 0x6c + DLT_NULL = 0 + DLT_PFLOG = 0x75 + DLT_PFSYNC = 0x12 + DLT_PPP = 0x9 + DLT_PPP_BSDOS = 0x10 + DLT_PPP_SERIAL = 0x32 + DLT_PRONET = 0x4 + DLT_RAW = 0xc + DLT_SLIP = 0x8 + DLT_SLIP_BSDOS = 0xf + DT_BLK = 0x6 + DT_CHR = 0x2 + DT_DIR = 0x4 + DT_FIFO = 0x1 + DT_LNK = 0xa + DT_REG = 0x8 + DT_SOCK = 0xc + DT_UNKNOWN = 0 + DT_WHT = 0xe E2BIG = 0x7 EACCES = 0xd EADDRINUSE = 0x30 @@ -196,6 +297,7 @@ const ( F_GETLK = 0x7 F_GETOWN = 0x5 F_GETPATH = 0x32 + F_GETPROTECTIONCLASS = 0x3e F_GLOBAL_NOCACHE = 0x37 F_LOG2PHYS = 0x31 F_MARKDEPENDENCY = 0x3c @@ -212,6 +314,7 @@ const ( F_SETLK = 0x8 F_SETLKW = 0x9 F_SETOWN = 0x6 + F_SETPROTECTIONCLASS = 0x3f F_SETSIZE = 0x2b F_THAW_FS = 0x36 F_UNLCK = 0x2 @@ -459,6 +562,16 @@ const ( IP_TOS = 0x3 IP_TRAFFIC_MGT_BACKGROUND = 0x41 IP_TTL = 0x4 + MADV_CAN_REUSE = 0x9 + MADV_DONTNEED = 0x4 + MADV_FREE = 0x5 + MADV_FREE_REUSABLE = 0x7 + MADV_FREE_REUSE = 0x8 + MADV_NORMAL = 0 + MADV_RANDOM = 0x1 + MADV_SEQUENTIAL = 0x2 + MADV_WILLNEED = 0x3 + MADV_ZERO_WIRED_PAGES = 0x6 MAP_ANON = 0x1000 MAP_COPY = 0x2 MAP_FILE = 0 @@ -556,6 +669,7 @@ const ( RTF_DYNAMIC = 0x10 RTF_GATEWAY = 0x2 RTF_HOST = 0x4 + RTF_IFREF = 0x4000000 RTF_IFSCOPE = 0x1000000 RTF_LLINFO = 0x400 RTF_LOCAL = 0x200000 @@ -649,6 +763,7 @@ const ( SIOCDIFADDR = 0x80206919 SIOCDIFPHYADDR = 0x80206941 SIOCDLIFADDR = 0x8118691f + SIOCGDRVSPEC = 0xc028697b SIOCGETSGCNT = 0xc014721c SIOCGETVIFCNT = 0xc014721b SIOCGETVLAN = 0xc020697f @@ -680,8 +795,10 @@ const ( SIOCGLOWAT = 0x40047303 SIOCGPGRP = 0x40047309 SIOCIFCREATE = 0xc0206978 + SIOCIFCREATE2 = 0xc020697a SIOCIFDESTROY = 0x80206979 SIOCRSLVMULTI = 0xc010693b + SIOCSDRVSPEC = 0x8028697b SIOCSETVLAN = 0x8020697e SIOCSHIWAT = 0x80047300 SIOCSIFADDR = 0x8020690c diff --git a/src/pkg/syscall/zsyscall_linux_386.go b/src/pkg/syscall/zsyscall_linux_386.go index 83f3bade1..4f331aa22 100644 --- a/src/pkg/syscall/zsyscall_linux_386.go +++ b/src/pkg/syscall/zsyscall_linux_386.go @@ -773,6 +773,78 @@ func munmap(addr uintptr, length uintptr) (errno int) { // THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT +func Madvise(b []byte, advice int) (errno int) { + var _p0 unsafe.Pointer + if len(b) > 0 { + _p0 = unsafe.Pointer(&b[0]) + } else { + _p0 = unsafe.Pointer(&_zero) + } + _, _, e1 := Syscall(SYS_MADVISE, uintptr(_p0), uintptr(len(b)), uintptr(advice)) + errno = int(e1) + return +} + +// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT + +func Mprotect(b []byte, prot int) (errno int) { + var _p0 unsafe.Pointer + if len(b) > 0 { + _p0 = unsafe.Pointer(&b[0]) + } else { + _p0 = unsafe.Pointer(&_zero) + } + _, _, e1 := Syscall(SYS_MPROTECT, uintptr(_p0), uintptr(len(b)), uintptr(prot)) + errno = int(e1) + return +} + +// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT + +func Mlock(b []byte) (errno int) { + var _p0 unsafe.Pointer + if len(b) > 0 { + _p0 = unsafe.Pointer(&b[0]) + } else { + _p0 = unsafe.Pointer(&_zero) + } + _, _, e1 := Syscall(SYS_MLOCK, uintptr(_p0), uintptr(len(b)), 0) + errno = int(e1) + return +} + +// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT + +func Munlock(b []byte) (errno int) { + var _p0 unsafe.Pointer + if len(b) > 0 { + _p0 = unsafe.Pointer(&b[0]) + } else { + _p0 = unsafe.Pointer(&_zero) + } + _, _, e1 := Syscall(SYS_MUNLOCK, uintptr(_p0), uintptr(len(b)), 0) + errno = int(e1) + return +} + +// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT + +func Mlockall(flags int) (errno int) { + _, _, e1 := Syscall(SYS_MLOCKALL, uintptr(flags), 0, 0) + errno = int(e1) + return +} + +// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT + +func Munlockall() (errno int) { + _, _, e1 := Syscall(SYS_MUNLOCKALL, 0, 0, 0) + errno = int(e1) + return +} + +// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT + func Chown(path string, uid int, gid int) (errno int) { _, _, e1 := Syscall(SYS_CHOWN32, uintptr(unsafe.Pointer(StringBytePtr(path))), uintptr(uid), uintptr(gid)) errno = int(e1) diff --git a/src/pkg/syscall/zsyscall_linux_amd64.go b/src/pkg/syscall/zsyscall_linux_amd64.go index c054349c6..19501dbfa 100644 --- a/src/pkg/syscall/zsyscall_linux_amd64.go +++ b/src/pkg/syscall/zsyscall_linux_amd64.go @@ -773,6 +773,78 @@ func munmap(addr uintptr, length uintptr) (errno int) { // THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT +func Madvise(b []byte, advice int) (errno int) { + var _p0 unsafe.Pointer + if len(b) > 0 { + _p0 = unsafe.Pointer(&b[0]) + } else { + _p0 = unsafe.Pointer(&_zero) + } + _, _, e1 := Syscall(SYS_MADVISE, uintptr(_p0), uintptr(len(b)), uintptr(advice)) + errno = int(e1) + return +} + +// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT + +func Mprotect(b []byte, prot int) (errno int) { + var _p0 unsafe.Pointer + if len(b) > 0 { + _p0 = unsafe.Pointer(&b[0]) + } else { + _p0 = unsafe.Pointer(&_zero) + } + _, _, e1 := Syscall(SYS_MPROTECT, uintptr(_p0), uintptr(len(b)), uintptr(prot)) + errno = int(e1) + return +} + +// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT + +func Mlock(b []byte) (errno int) { + var _p0 unsafe.Pointer + if len(b) > 0 { + _p0 = unsafe.Pointer(&b[0]) + } else { + _p0 = unsafe.Pointer(&_zero) + } + _, _, e1 := Syscall(SYS_MLOCK, uintptr(_p0), uintptr(len(b)), 0) + errno = int(e1) + return +} + +// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT + +func Munlock(b []byte) (errno int) { + var _p0 unsafe.Pointer + if len(b) > 0 { + _p0 = unsafe.Pointer(&b[0]) + } else { + _p0 = unsafe.Pointer(&_zero) + } + _, _, e1 := Syscall(SYS_MUNLOCK, uintptr(_p0), uintptr(len(b)), 0) + errno = int(e1) + return +} + +// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT + +func Mlockall(flags int) (errno int) { + _, _, e1 := Syscall(SYS_MLOCKALL, uintptr(flags), 0, 0) + errno = int(e1) + return +} + +// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT + +func Munlockall() (errno int) { + _, _, e1 := Syscall(SYS_MUNLOCKALL, 0, 0, 0) + errno = int(e1) + return +} + +// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT + func Chown(path string, uid int, gid int) (errno int) { _, _, e1 := Syscall(SYS_CHOWN, uintptr(unsafe.Pointer(StringBytePtr(path))), uintptr(uid), uintptr(gid)) errno = int(e1) diff --git a/src/pkg/syscall/zsyscall_linux_arm.go b/src/pkg/syscall/zsyscall_linux_arm.go index 49d164a3c..db49b6482 100644 --- a/src/pkg/syscall/zsyscall_linux_arm.go +++ b/src/pkg/syscall/zsyscall_linux_arm.go @@ -773,6 +773,78 @@ func munmap(addr uintptr, length uintptr) (errno int) { // THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT +func Madvise(b []byte, advice int) (errno int) { + var _p0 unsafe.Pointer + if len(b) > 0 { + _p0 = unsafe.Pointer(&b[0]) + } else { + _p0 = unsafe.Pointer(&_zero) + } + _, _, e1 := Syscall(SYS_MADVISE, uintptr(_p0), uintptr(len(b)), uintptr(advice)) + errno = int(e1) + return +} + +// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT + +func Mprotect(b []byte, prot int) (errno int) { + var _p0 unsafe.Pointer + if len(b) > 0 { + _p0 = unsafe.Pointer(&b[0]) + } else { + _p0 = unsafe.Pointer(&_zero) + } + _, _, e1 := Syscall(SYS_MPROTECT, uintptr(_p0), uintptr(len(b)), uintptr(prot)) + errno = int(e1) + return +} + +// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT + +func Mlock(b []byte) (errno int) { + var _p0 unsafe.Pointer + if len(b) > 0 { + _p0 = unsafe.Pointer(&b[0]) + } else { + _p0 = unsafe.Pointer(&_zero) + } + _, _, e1 := Syscall(SYS_MLOCK, uintptr(_p0), uintptr(len(b)), 0) + errno = int(e1) + return +} + +// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT + +func Munlock(b []byte) (errno int) { + var _p0 unsafe.Pointer + if len(b) > 0 { + _p0 = unsafe.Pointer(&b[0]) + } else { + _p0 = unsafe.Pointer(&_zero) + } + _, _, e1 := Syscall(SYS_MUNLOCK, uintptr(_p0), uintptr(len(b)), 0) + errno = int(e1) + return +} + +// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT + +func Mlockall(flags int) (errno int) { + _, _, e1 := Syscall(SYS_MLOCKALL, uintptr(flags), 0, 0) + errno = int(e1) + return +} + +// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT + +func Munlockall() (errno int) { + _, _, e1 := Syscall(SYS_MUNLOCKALL, 0, 0, 0) + errno = int(e1) + return +} + +// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT + func accept(s int, rsa *RawSockaddrAny, addrlen *_Socklen) (fd int, errno int) { r0, _, e1 := Syscall(SYS_ACCEPT, uintptr(s), uintptr(unsafe.Pointer(rsa)), uintptr(unsafe.Pointer(addrlen))) fd = int(r0) @@ -942,14 +1014,6 @@ func Fstatfs(fd int, buf *Statfs_t) (errno int) { // THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT -func Ftruncate(fd int, length int64) (errno int) { - _, _, e1 := Syscall(SYS_FTRUNCATE64, uintptr(fd), uintptr(length>>32), uintptr(length)) - errno = int(e1) - return -} - -// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT - func Getegid() (egid int) { r0, _, _ := RawSyscall(SYS_GETEGID, 0, 0, 0) egid = int(r0) @@ -1104,14 +1168,6 @@ func Statfs(path string, buf *Statfs_t) (errno int) { // THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT -func Truncate(path string, length int64) (errno int) { - _, _, e1 := Syscall(SYS_TRUNCATE64, uintptr(unsafe.Pointer(StringBytePtr(path))), uintptr(length>>32), uintptr(length)) - errno = int(e1) - return -} - -// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT - func Gettimeofday(tv *Timeval) (errno int) { _, _, e1 := RawSyscall(SYS_GETTIMEOFDAY, uintptr(unsafe.Pointer(tv)), 0, 0) errno = int(e1) diff --git a/src/pkg/syscall/ztypes_darwin_386.go b/src/pkg/syscall/ztypes_darwin_386.go index 736c654ab..b3541778e 100644 --- a/src/pkg/syscall/ztypes_darwin_386.go +++ b/src/pkg/syscall/ztypes_darwin_386.go @@ -29,6 +29,11 @@ const ( SizeofIfaMsghdr = 0x14 SizeofRtMsghdr = 0x5c SizeofRtMetrics = 0x38 + SizeofBpfVersion = 0x4 + SizeofBpfStat = 0x8 + SizeofBpfProgram = 0x8 + SizeofBpfInsn = 0x8 + SizeofBpfHdr = 0x14 ) // Types @@ -334,3 +339,33 @@ type RtMetrics struct { Pksent uint32 Filler [4]uint32 } + +type BpfVersion struct { + Major uint16 + Minor uint16 +} + +type BpfStat struct { + Recv uint32 + Drop uint32 +} + +type BpfProgram struct { + Len uint32 + Insns *BpfInsn +} + +type BpfInsn struct { + Code uint16 + Jt uint8 + Jf uint8 + K uint32 +} + +type BpfHdr struct { + Tstamp Timeval + Caplen uint32 + Datalen uint32 + Hdrlen uint16 + Pad_godefs_0 [2]byte +} diff --git a/src/pkg/syscall/ztypes_darwin_amd64.go b/src/pkg/syscall/ztypes_darwin_amd64.go index 936a4e804..d61c8b8de 100644 --- a/src/pkg/syscall/ztypes_darwin_amd64.go +++ b/src/pkg/syscall/ztypes_darwin_amd64.go @@ -29,6 +29,11 @@ const ( SizeofIfaMsghdr = 0x14 SizeofRtMsghdr = 0x5c SizeofRtMetrics = 0x38 + SizeofBpfVersion = 0x4 + SizeofBpfStat = 0x8 + SizeofBpfProgram = 0x10 + SizeofBpfInsn = 0x8 + SizeofBpfHdr = 0x14 ) // Types @@ -52,6 +57,11 @@ type Timeval struct { Pad_godefs_0 [4]byte } +type Timeval32 struct { + Sec int32 + Usec int32 +} + type Rusage struct { Utime Timeval Stime Timeval @@ -229,7 +239,7 @@ type Msghdr struct { Name *byte Namelen uint32 Pad_godefs_0 [4]byte - Iov uint64 + Iov *Iovec Iovlen int32 Pad_godefs_1 [4]byte Control *byte @@ -292,7 +302,7 @@ type IfData struct { Noproto uint32 Recvtiming uint32 Xmittiming uint32 - Lastchange [8]byte /* timeval32 */ + Lastchange Timeval32 Unused2 uint32 Hwassist uint32 Reserved1 uint32 @@ -339,3 +349,34 @@ type RtMetrics struct { Pksent uint32 Filler [4]uint32 } + +type BpfVersion struct { + Major uint16 + Minor uint16 +} + +type BpfStat struct { + Recv uint32 + Drop uint32 +} + +type BpfProgram struct { + Len uint32 + Pad_godefs_0 [4]byte + Insns *BpfInsn +} + +type BpfInsn struct { + Code uint16 + Jt uint8 + Jf uint8 + K uint32 +} + +type BpfHdr struct { + Tstamp Timeval32 + Caplen uint32 + Datalen uint32 + Hdrlen uint16 + Pad_godefs_0 [2]byte +} diff --git a/src/pkg/syscall/ztypes_windows_386.go b/src/pkg/syscall/ztypes_windows_386.go index 56d4198dc..3a50be14c 100644 --- a/src/pkg/syscall/ztypes_windows_386.go +++ b/src/pkg/syscall/ztypes_windows_386.go @@ -77,6 +77,7 @@ const ( HANDLE_FLAG_INHERIT = 0x00000001 STARTF_USESTDHANDLES = 0x00000100 + STARTF_USESHOWWINDOW = 0x00000001 DUPLICATE_CLOSE_SOURCE = 0x00000001 DUPLICATE_SAME_ACCESS = 0x00000002 @@ -240,6 +241,25 @@ type ByHandleFileInformation struct { FileIndexLow uint32 } +// ShowWindow constants +const ( + // winuser.h + SW_HIDE = 0 + SW_NORMAL = 1 + SW_SHOWNORMAL = 1 + SW_SHOWMINIMIZED = 2 + SW_SHOWMAXIMIZED = 3 + SW_MAXIMIZE = 3 + SW_SHOWNOACTIVATE = 4 + SW_SHOW = 5 + SW_MINIMIZE = 6 + SW_SHOWMINNOACTIVE = 7 + SW_SHOWNA = 8 + SW_RESTORE = 9 + SW_SHOWDEFAULT = 10 + SW_FORCEMINIMIZE = 11 +) + type StartupInfo struct { Cb uint32 _ *uint16 diff --git a/src/pkg/syslog/syslog.go b/src/pkg/syslog/syslog.go index 4ada113f1..693337212 100644 --- a/src/pkg/syslog/syslog.go +++ b/src/pkg/syslog/syslog.go @@ -2,9 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The syslog package provides a simple interface to -// the system log service. It can send messages to the -// syslog daemon using UNIX domain sockets, UDP, or +// Package syslog provides a simple interface to the system log service. It +// can send messages to the syslog daemon using UNIX domain sockets, UDP, or // TCP connections. package syslog diff --git a/src/pkg/syslog/syslog_test.go b/src/pkg/syslog/syslog_test.go index 2958bcb1f..4816ddf2a 100644 --- a/src/pkg/syslog/syslog_test.go +++ b/src/pkg/syslog/syslog_test.go @@ -52,6 +52,10 @@ func TestNewLogger(t *testing.T) { } func TestDial(t *testing.T) { + if testing.Short() { + // Depends on syslog daemon running, and sometimes it's not. + t.Logf("skipping syslog test during -short") + } l, err := Dial("", "", LOG_ERR, "syslog_test") if err != nil { t.Fatalf("Dial() failed: %s", err) diff --git a/src/pkg/tabwriter/tabwriter.go b/src/pkg/tabwriter/tabwriter.go index 848703e8c..d91a07db2 100644 --- a/src/pkg/tabwriter/tabwriter.go +++ b/src/pkg/tabwriter/tabwriter.go @@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The tabwriter package implements a write filter (tabwriter.Writer) -// that translates tabbed columns in input into properly aligned text. +// Package tabwriter implements a write filter (tabwriter.Writer) that +// translates tabbed columns in input into properly aligned text. // // The package is using the Elastic Tabstops algorithm described at // http://nickgravgaard.com/elastictabstops/index.html. diff --git a/src/pkg/template/template.go b/src/pkg/template/template.go index 28872dbee..253207852 100644 --- a/src/pkg/template/template.go +++ b/src/pkg/template/template.go @@ -3,8 +3,8 @@ // license that can be found in the LICENSE file. /* - Data-driven templates for generating textual output such as - HTML. + Package template implements data-driven templates for generating textual + output such as HTML. Templates are executed by applying them to a data structure. Annotations in the template refer to elements of the data @@ -646,7 +646,7 @@ func (t *Template) lookup(st *state, v reflect.Value, name string) reflect.Value } return av.FieldByName(name) case reflect.Map: - if v := av.MapIndex(reflect.NewValue(name)); v.IsValid() { + if v := av.MapIndex(reflect.ValueOf(name)); v.IsValid() { return v } return reflect.Zero(typ.Elem()) @@ -797,7 +797,7 @@ func (t *Template) executeElement(i int, st *state) int { return elem.end } e := t.elems.At(i) - t.execError(st, 0, "internal error: bad directive in execute: %v %T\n", reflect.NewValue(e).Interface(), e) + t.execError(st, 0, "internal error: bad directive in execute: %v %T\n", reflect.ValueOf(e).Interface(), e) return 0 } @@ -980,7 +980,7 @@ func (t *Template) ParseFile(filename string) (err os.Error) { // generating output to wr. func (t *Template) Execute(wr io.Writer, data interface{}) (err os.Error) { // Extract the driver data. - val := reflect.NewValue(data) + val := reflect.ValueOf(data) defer checkError(&err) t.p = 0 t.execute(0, t.elems.Len(), &state{parent: nil, data: val, wr: wr}) diff --git a/src/pkg/testing/iotest/reader.go b/src/pkg/testing/iotest/reader.go index 647520a09..e4003d744 100644 --- a/src/pkg/testing/iotest/reader.go +++ b/src/pkg/testing/iotest/reader.go @@ -2,8 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The iotest package implements Readers and Writers -// useful only for testing. +// Package iotest implements Readers and Writers useful only for testing. package iotest import ( diff --git a/src/pkg/testing/quick/quick.go b/src/pkg/testing/quick/quick.go index 52fd38d9c..756a60e13 100644 --- a/src/pkg/testing/quick/quick.go +++ b/src/pkg/testing/quick/quick.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements utility functions to help with black box testing. +// Package quick implements utility functions to help with black box testing. package quick import ( @@ -59,37 +59,37 @@ func Value(t reflect.Type, rand *rand.Rand) (value reflect.Value, ok bool) { switch concrete := t; concrete.Kind() { case reflect.Bool: - return reflect.NewValue(rand.Int()&1 == 0), true + return reflect.ValueOf(rand.Int()&1 == 0), true case reflect.Float32: - return reflect.NewValue(randFloat32(rand)), true + return reflect.ValueOf(randFloat32(rand)), true case reflect.Float64: - return reflect.NewValue(randFloat64(rand)), true + return reflect.ValueOf(randFloat64(rand)), true case reflect.Complex64: - return reflect.NewValue(complex(randFloat32(rand), randFloat32(rand))), true + return reflect.ValueOf(complex(randFloat32(rand), randFloat32(rand))), true case reflect.Complex128: - return reflect.NewValue(complex(randFloat64(rand), randFloat64(rand))), true + return reflect.ValueOf(complex(randFloat64(rand), randFloat64(rand))), true case reflect.Int16: - return reflect.NewValue(int16(randInt64(rand))), true + return reflect.ValueOf(int16(randInt64(rand))), true case reflect.Int32: - return reflect.NewValue(int32(randInt64(rand))), true + return reflect.ValueOf(int32(randInt64(rand))), true case reflect.Int64: - return reflect.NewValue(randInt64(rand)), true + return reflect.ValueOf(randInt64(rand)), true case reflect.Int8: - return reflect.NewValue(int8(randInt64(rand))), true + return reflect.ValueOf(int8(randInt64(rand))), true case reflect.Int: - return reflect.NewValue(int(randInt64(rand))), true + return reflect.ValueOf(int(randInt64(rand))), true case reflect.Uint16: - return reflect.NewValue(uint16(randInt64(rand))), true + return reflect.ValueOf(uint16(randInt64(rand))), true case reflect.Uint32: - return reflect.NewValue(uint32(randInt64(rand))), true + return reflect.ValueOf(uint32(randInt64(rand))), true case reflect.Uint64: - return reflect.NewValue(uint64(randInt64(rand))), true + return reflect.ValueOf(uint64(randInt64(rand))), true case reflect.Uint8: - return reflect.NewValue(uint8(randInt64(rand))), true + return reflect.ValueOf(uint8(randInt64(rand))), true case reflect.Uint: - return reflect.NewValue(uint(randInt64(rand))), true + return reflect.ValueOf(uint(randInt64(rand))), true case reflect.Uintptr: - return reflect.NewValue(uintptr(randInt64(rand))), true + return reflect.ValueOf(uintptr(randInt64(rand))), true case reflect.Map: numElems := rand.Intn(complexSize) m := reflect.MakeMap(concrete) @@ -107,8 +107,8 @@ func Value(t reflect.Type, rand *rand.Rand) (value reflect.Value, ok bool) { if !ok { return reflect.Value{}, false } - p := reflect.Zero(concrete) - p.Set(v.Addr()) + p := reflect.New(concrete.Elem()) + p.Elem().Set(v) return p, true case reflect.Slice: numElems := rand.Intn(complexSize) @@ -127,9 +127,9 @@ func Value(t reflect.Type, rand *rand.Rand) (value reflect.Value, ok bool) { for i := 0; i < numChars; i++ { codePoints[i] = rand.Intn(0x10ffff) } - return reflect.NewValue(string(codePoints)), true + return reflect.ValueOf(string(codePoints)), true case reflect.Struct: - s := reflect.Zero(t) + s := reflect.New(t).Elem() for i := 0; i < s.NumField(); i++ { v, ok := Value(concrete.Field(i).Type, rand) if !ok { @@ -336,7 +336,7 @@ func arbitraryValues(args []reflect.Value, f reflect.Type, config *Config, rand } func functionAndType(f interface{}) (v reflect.Value, t reflect.Type, ok bool) { - v = reflect.NewValue(f) + v = reflect.ValueOf(f) ok = v.Kind() == reflect.Func if !ok { return diff --git a/src/pkg/testing/quick/quick_test.go b/src/pkg/testing/quick/quick_test.go index b126e4a16..f2618c3c2 100644 --- a/src/pkg/testing/quick/quick_test.go +++ b/src/pkg/testing/quick/quick_test.go @@ -102,7 +102,7 @@ type myStruct struct { } func (m myStruct) Generate(r *rand.Rand, _ int) reflect.Value { - return reflect.NewValue(myStruct{x: 42}) + return reflect.ValueOf(myStruct{x: 42}) } func myStructProperty(in myStruct) bool { return in.x == 42 } diff --git a/src/pkg/testing/script/script.go b/src/pkg/testing/script/script.go index b18018497..afb286f5b 100644 --- a/src/pkg/testing/script/script.go +++ b/src/pkg/testing/script/script.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package aids in the testing of code that uses channels. +// Package script aids in the testing of code that uses channels. package script import ( @@ -134,19 +134,19 @@ type empty struct { } func newEmptyInterface(e empty) reflect.Value { - return reflect.NewValue(e).Field(0) + return reflect.ValueOf(e).Field(0) } func (s Send) send() { // With reflect.ChanValue.Send, we must match the types exactly. So, if // s.Channel is a chan interface{} we convert s.Value to an interface{} // first. - c := reflect.NewValue(s.Channel) + c := reflect.ValueOf(s.Channel) var v reflect.Value if iface := c.Type().Elem(); iface.Kind() == reflect.Interface && iface.NumMethod() == 0 { v = newEmptyInterface(empty{s.Value}) } else { - v = reflect.NewValue(s.Value) + v = reflect.ValueOf(s.Value) } c.Send(v) } @@ -162,7 +162,7 @@ func (s Close) getSend() sendAction { return s } func (s Close) getChannel() interface{} { return s.Channel } -func (s Close) send() { reflect.NewValue(s.Channel).Close() } +func (s Close) send() { reflect.ValueOf(s.Channel).Close() } // A ReceivedUnexpected error results if no active Events match a value // received from a channel. @@ -278,7 +278,7 @@ func getChannels(events []*Event) ([]interface{}, os.Error) { continue } c := event.action.getChannel() - if reflect.NewValue(c).Kind() != reflect.Chan { + if reflect.ValueOf(c).Kind() != reflect.Chan { return nil, SetupError("one of the channel values is not a channel") } @@ -303,7 +303,7 @@ func getChannels(events []*Event) ([]interface{}, os.Error) { // channel repeatedly, wrapping them up as either a channelRecv or // channelClosed structure, and forwards them to the multiplex channel. func recvValues(multiplex chan<- interface{}, channel interface{}) { - c := reflect.NewValue(channel) + c := reflect.ValueOf(channel) for { v, ok := c.Recv() diff --git a/src/pkg/testing/testing.go b/src/pkg/testing/testing.go index 1e65528ef..8781b207d 100644 --- a/src/pkg/testing/testing.go +++ b/src/pkg/testing/testing.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The testing package provides support for automated testing of Go packages. +// Package testing provides support for automated testing of Go packages. // It is intended to be used in concert with the ``gotest'' utility, which automates // execution of any function of the form // func TestXxx(*testing.T) diff --git a/src/pkg/time/time.go b/src/pkg/time/time.go index 40338f775..a0480786a 100644 --- a/src/pkg/time/time.go +++ b/src/pkg/time/time.go @@ -2,8 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The time package provides functionality for measuring and -// displaying time. +// Package time provides functionality for measuring and displaying time. package time // Days of the week. diff --git a/src/pkg/time/zoneinfo_unix.go b/src/pkg/time/zoneinfo_unix.go index 6685da747..42659ed60 100644 --- a/src/pkg/time/zoneinfo_unix.go +++ b/src/pkg/time/zoneinfo_unix.go @@ -17,8 +17,6 @@ import ( const ( headerSize = 4 + 16 + 4*7 - zoneDir = "/usr/share/zoneinfo/" - zoneDir2 = "/usr/share/lib/zoneinfo/" ) // Simple I/O interface to binary blob of data. @@ -211,16 +209,22 @@ func setupZone() { // no $TZ means use the system default /etc/localtime. // $TZ="" means use UTC. // $TZ="foo" means use /usr/share/zoneinfo/foo. + // Many systems use /usr/share/zoneinfo, Solaris 2 has + // /usr/share/lib/zoneinfo, IRIX 6 has /usr/lib/locale/TZ. + zoneDirs := []string{"/usr/share/zoneinfo/", + "/usr/share/lib/zoneinfo/", + "/usr/lib/locale/TZ/"} tz, err := os.Getenverror("TZ") switch { case err == os.ENOENV: zones, _ = readinfofile("/etc/localtime") case len(tz) > 0: - var ok bool - zones, ok = readinfofile(zoneDir + tz) - if !ok { - zones, _ = readinfofile(zoneDir2 + tz) + for _, zoneDir := range zoneDirs { + var ok bool + if zones, ok = readinfofile(zoneDir + tz); ok { + break + } } case len(tz) == 0: // do nothing: use UTC diff --git a/src/pkg/try/try.go b/src/pkg/try/try.go index 1171c80c2..2a3dbf987 100644 --- a/src/pkg/try/try.go +++ b/src/pkg/try/try.go @@ -67,7 +67,7 @@ func printSlice(firstArg string, args []interface{}) { func tryMethods(pkg, firstArg string, args []interface{}) { defer func() { recover() }() // Is the first argument something with methods? - v := reflect.NewValue(args[0]) + v := reflect.ValueOf(args[0]) typ := v.Type() if typ.NumMethod() == 0 { return @@ -90,7 +90,7 @@ func tryMethod(pkg, firstArg string, method reflect.Method, args []interface{}) // tryFunction sees if fn satisfies the arguments. func tryFunction(pkg, name string, fn interface{}, args []interface{}) { defer func() { recover() }() - rfn := reflect.NewValue(fn) + rfn := reflect.ValueOf(fn) typ := rfn.Type() tryOneFunction(pkg, "", name, typ, rfn, args) } @@ -120,7 +120,7 @@ func tryOneFunction(pkg, firstArg, name string, typ reflect.Type, rfn reflect.Va // Build the call args. argsVal := make([]reflect.Value, typ.NumIn()+typ.NumOut()) for i, a := range args { - argsVal[i] = reflect.NewValue(a) + argsVal[i] = reflect.ValueOf(a) } // Call the function and see if the results are as expected. resultVal := rfn.Call(argsVal[:typ.NumIn()]) @@ -161,7 +161,7 @@ func tryOneFunction(pkg, firstArg, name string, typ reflect.Type, rfn reflect.Va // compatible reports whether the argument is compatible with the type. func compatible(arg interface{}, typ reflect.Type) bool { - if reflect.Typeof(arg) == typ { + if reflect.TypeOf(arg) == typ { return true } if arg == nil { diff --git a/src/pkg/unicode/letter.go b/src/pkg/unicode/letter.go index 9380624fd..382c6eb3f 100644 --- a/src/pkg/unicode/letter.go +++ b/src/pkg/unicode/letter.go @@ -2,7 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package provides data and functions to test some properties of Unicode code points. +// Package unicode provides data and functions to test some properties of +// Unicode code points. package unicode const ( diff --git a/src/pkg/unsafe/unsafe.go b/src/pkg/unsafe/unsafe.go index 3cd4cff6e..8507bed52 100644 --- a/src/pkg/unsafe/unsafe.go +++ b/src/pkg/unsafe/unsafe.go @@ -3,7 +3,7 @@ // license that can be found in the LICENSE file. /* - The unsafe package contains operations that step around the type safety of Go programs. + Package unsafe contains operations that step around the type safety of Go programs. */ package unsafe diff --git a/src/pkg/utf8/utf8.go b/src/pkg/utf8/utf8.go index 455499e4d..f542358d6 100644 --- a/src/pkg/utf8/utf8.go +++ b/src/pkg/utf8/utf8.go @@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Functions and constants to support text encoded in UTF-8. -// This package calls a Unicode character a rune for brevity. +// Package utf8 implements functions and constants to support text encoded in +// UTF-8. This package calls a Unicode character a rune for brevity. package utf8 import "unicode" // only needed for a couple of constants diff --git a/src/pkg/websocket/server.go b/src/pkg/websocket/server.go index 1119b2d34..376265236 100644 --- a/src/pkg/websocket/server.go +++ b/src/pkg/websocket/server.go @@ -150,6 +150,7 @@ func (f Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) { return } ws := newConn(origin, location, protocol, buf, rwc) + ws.Request = req f(ws) } diff --git a/src/pkg/websocket/websocket.go b/src/pkg/websocket/websocket.go index d5996abe1..edde61b4a 100644 --- a/src/pkg/websocket/websocket.go +++ b/src/pkg/websocket/websocket.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The websocket package implements a client and server for the Web Socket protocol. +// Package websocket implements a client and server for the Web Socket protocol. // The protocol is defined at http://tools.ietf.org/html/draft-hixie-thewebsocketprotocol package websocket @@ -13,6 +13,7 @@ import ( "bufio" "crypto/md5" "encoding/binary" + "http" "io" "net" "os" @@ -43,6 +44,8 @@ type Conn struct { Location string // The subprotocol for the Web Socket. Protocol string + // The initial http Request (for the Server side only). + Request *http.Request buf *bufio.ReadWriter rwc io.ReadWriteCloser diff --git a/src/pkg/websocket/websocket_test.go b/src/pkg/websocket/websocket_test.go index 8b3cf8925..10f88dfd1 100644 --- a/src/pkg/websocket/websocket_test.go +++ b/src/pkg/websocket/websocket_test.go @@ -186,11 +186,12 @@ func TestTrailingSpaces(t *testing.T) { once.Do(startServer) for i := 0; i < 30; i++ { // body - _, err := Dial(fmt.Sprintf("ws://%s/echo", serverAddr), "", - "http://localhost/") + ws, err := Dial(fmt.Sprintf("ws://%s/echo", serverAddr), "", "http://localhost/") if err != nil { - panic("Dial failed: " + err.String()) + t.Error("Dial failed:", err.String()) + break } + ws.Close() } } diff --git a/src/pkg/xml/read.go b/src/pkg/xml/read.go index a3ddb9d4c..554b2a61b 100644 --- a/src/pkg/xml/read.go +++ b/src/pkg/xml/read.go @@ -139,7 +139,7 @@ import ( // to a freshly allocated value and then mapping the element to that value. // func Unmarshal(r io.Reader, val interface{}) os.Error { - v := reflect.NewValue(val) + v := reflect.ValueOf(val) if v.Kind() != reflect.Ptr { return os.NewError("non-pointer passed to Unmarshal") } @@ -176,7 +176,7 @@ func (e *TagPathError) String() string { // Passing a nil start element indicates that Unmarshal should // read the token stream to find the start element. func (p *Parser) Unmarshal(val interface{}, start *StartElement) os.Error { - v := reflect.NewValue(val) + v := reflect.ValueOf(val) if v.Kind() != reflect.Ptr { return os.NewError("non-pointer passed to Unmarshal") } @@ -280,7 +280,7 @@ func (p *Parser) unmarshal(val reflect.Value, start *StartElement) os.Error { case reflect.Struct: if _, ok := v.Interface().(Name); ok { - v.Set(reflect.NewValue(start.Name)) + v.Set(reflect.ValueOf(start.Name)) break } @@ -316,7 +316,7 @@ func (p *Parser) unmarshal(val reflect.Value, start *StartElement) os.Error { if _, ok := v.Interface().(Name); !ok { return UnmarshalError(sv.Type().String() + " field XMLName does not have type xml.Name") } - v.Set(reflect.NewValue(start.Name)) + v.Set(reflect.ValueOf(start.Name)) } // Assign attributes. @@ -508,21 +508,21 @@ Loop: case reflect.String: t.SetString(string(data)) case reflect.Slice: - t.Set(reflect.NewValue(data)) + t.Set(reflect.ValueOf(data)) } switch t := saveComment; t.Kind() { case reflect.String: t.SetString(string(comment)) case reflect.Slice: - t.Set(reflect.NewValue(comment)) + t.Set(reflect.ValueOf(comment)) } switch t := saveXML; t.Kind() { case reflect.String: t.SetString(string(saveXMLData)) case reflect.Slice: - t.Set(reflect.NewValue(saveXMLData)) + t.Set(reflect.ValueOf(saveXMLData)) } return nil diff --git a/src/pkg/xml/read_test.go b/src/pkg/xml/read_test.go index 0e28e73a6..d4ae3700d 100644 --- a/src/pkg/xml/read_test.go +++ b/src/pkg/xml/read_test.go @@ -288,9 +288,7 @@ var pathTests = []interface{}{ func TestUnmarshalPaths(t *testing.T) { for _, pt := range pathTests { - p := reflect.Zero(reflect.NewValue(pt).Type()) - p.Set(reflect.Zero(p.Type().Elem()).Addr()) - v := p.Interface() + v := reflect.New(reflect.TypeOf(pt).Elem()).Interface() if err := Unmarshal(StringReader(pathTestString), v); err != nil { t.Fatalf("Unmarshal: %s", err) } @@ -315,8 +313,8 @@ type BadPathTestB struct { var badPathTests = []struct { v, e interface{} }{ - {&BadPathTestA{}, &TagPathError{reflect.Typeof(BadPathTestA{}), "First", "items>item1", "Second", "items>"}}, - {&BadPathTestB{}, &TagPathError{reflect.Typeof(BadPathTestB{}), "First", "items>item1", "Second", "items>item1>value"}}, + {&BadPathTestA{}, &TagPathError{reflect.TypeOf(BadPathTestA{}), "First", "items>item1", "Second", "items>"}}, + {&BadPathTestB{}, &TagPathError{reflect.TypeOf(BadPathTestB{}), "First", "items>item1", "Second", "items>item1>value"}}, } func TestUnmarshalBadPaths(t *testing.T) { diff --git a/src/pkg/xml/xml.go b/src/pkg/xml/xml.go index f92abe825..42d8b986e 100644 --- a/src/pkg/xml/xml.go +++ b/src/pkg/xml/xml.go @@ -163,6 +163,13 @@ type Parser struct { // "quot": `"`, Entity map[string]string + // CharsetReader, if non-nil, defines a function to generate + // charset-conversion readers, converting from the provided + // non-UTF-8 charset into UTF-8. If CharsetReader is nil or + // returns an error, parsing stops with an error. One of the + // the CharsetReader's result values must be non-nil. + CharsetReader func(charset string, input io.Reader) (io.Reader, os.Error) + r io.ByteReader buf bytes.Buffer saved *bytes.Buffer @@ -186,17 +193,7 @@ func NewParser(r io.Reader) *Parser { line: 1, Strict: true, } - - // Get efficient byte at a time reader. - // Assume that if reader has its own - // ReadByte, it's efficient enough. - // Otherwise, use bufio. - if rb, ok := r.(io.ByteReader); ok { - p.r = rb - } else { - p.r = bufio.NewReader(r) - } - + p.switchToReader(r) return p } @@ -290,6 +287,18 @@ func (p *Parser) translate(n *Name, isElementName bool) { } } +func (p *Parser) switchToReader(r io.Reader) { + // Get efficient byte at a time reader. + // Assume that if reader has its own + // ReadByte, it's efficient enough. + // Otherwise, use bufio. + if rb, ok := r.(io.ByteReader); ok { + p.r = rb + } else { + p.r = bufio.NewReader(r) + } +} + // Parsing state - stack holds old name space translations // and the current set of open elements. The translations to pop when // ending a given tag are *below* it on the stack, which is @@ -487,6 +496,25 @@ func (p *Parser) RawToken() (Token, os.Error) { } data := p.buf.Bytes() data = data[0 : len(data)-2] // chop ?> + + if target == "xml" { + enc := procInstEncoding(string(data)) + if enc != "" && enc != "utf-8" && enc != "UTF-8" { + if p.CharsetReader == nil { + p.err = fmt.Errorf("xml: encoding %q declared but Parser.CharsetReader is nil", enc) + return nil, p.err + } + newr, err := p.CharsetReader(enc, p.r.(io.Reader)) + if err != nil { + p.err = fmt.Errorf("xml: opening charset %q: %v", enc, err) + return nil, p.err + } + if newr == nil { + panic("CharsetReader returned a nil Reader for charset " + enc) + } + p.switchToReader(newr) + } + } return ProcInst{target, data}, nil case '!': @@ -1633,3 +1661,26 @@ func Escape(w io.Writer, s []byte) { } w.Write(s[last:]) } + +// procInstEncoding parses the `encoding="..."` or `encoding='...'` +// value out of the provided string, returning "" if not found. +func procInstEncoding(s string) string { + // TODO: this parsing is somewhat lame and not exact. + // It works for all actual cases, though. + idx := strings.Index(s, "encoding=") + if idx == -1 { + return "" + } + v := s[idx+len("encoding="):] + if v == "" { + return "" + } + if v[0] != '\'' && v[0] != '"' { + return "" + } + idx = strings.IndexRune(v[1:], int(v[0])) + if idx == -1 { + return "" + } + return v[1 : idx+1] +} diff --git a/src/pkg/xml/xml_test.go b/src/pkg/xml/xml_test.go index 887bc3d14..a99c1919e 100644 --- a/src/pkg/xml/xml_test.go +++ b/src/pkg/xml/xml_test.go @@ -9,6 +9,7 @@ import ( "io" "os" "reflect" + "strings" "testing" ) @@ -96,6 +97,19 @@ var cookedTokens = []Token{ Comment([]byte(" missing final newline ")), } +const testInputAltEncoding = ` +<?xml version="1.0" encoding="x-testing-uppercase"?> +<TAG>VALUE</TAG>` + +var rawTokensAltEncoding = []Token{ + CharData([]byte("\n")), + ProcInst{"xml", []byte(`version="1.0" encoding="x-testing-uppercase"`)}, + CharData([]byte("\n")), + StartElement{Name{"", "tag"}, nil}, + CharData([]byte("value")), + EndElement{Name{"", "tag"}}, +} + var xmlInput = []string{ // unexpected EOF cases "<", @@ -173,7 +187,64 @@ func StringReader(s string) io.Reader { return &stringReader{s, 0} } func TestRawToken(t *testing.T) { p := NewParser(StringReader(testInput)) + testRawToken(t, p, rawTokens) +} + +type downCaser struct { + t *testing.T + r io.ByteReader +} + +func (d *downCaser) ReadByte() (c byte, err os.Error) { + c, err = d.r.ReadByte() + if c >= 'A' && c <= 'Z' { + c += 'a' - 'A' + } + return +} + +func (d *downCaser) Read(p []byte) (int, os.Error) { + d.t.Fatalf("unexpected Read call on downCaser reader") + return 0, os.EINVAL +} + +func TestRawTokenAltEncoding(t *testing.T) { + sawEncoding := "" + p := NewParser(StringReader(testInputAltEncoding)) + p.CharsetReader = func(charset string, input io.Reader) (io.Reader, os.Error) { + sawEncoding = charset + if charset != "x-testing-uppercase" { + t.Fatalf("unexpected charset %q", charset) + } + return &downCaser{t, input.(io.ByteReader)}, nil + } + testRawToken(t, p, rawTokensAltEncoding) +} +func TestRawTokenAltEncodingNoConverter(t *testing.T) { + p := NewParser(StringReader(testInputAltEncoding)) + token, err := p.RawToken() + if token == nil { + t.Fatalf("expected a token on first RawToken call") + } + if err != nil { + t.Fatal(err) + } + token, err = p.RawToken() + if token != nil { + t.Errorf("expected a nil token; got %#v", token) + } + if err == nil { + t.Fatalf("expected an error on second RawToken call") + } + const encoding = "x-testing-uppercase" + if !strings.Contains(err.String(), encoding) { + t.Errorf("expected error to contain %q; got error: %v", + encoding, err) + } +} + +func testRawToken(t *testing.T, p *Parser, rawTokens []Token) { for i, want := range rawTokens { have, err := p.RawToken() if err != nil { @@ -483,3 +554,26 @@ func TestDisallowedCharacters(t *testing.T) { } } } + +type procInstEncodingTest struct { + expect, got string +} + +var procInstTests = []struct { + input, expect string +}{ + {`version="1.0" encoding="utf-8"`, "utf-8"}, + {`version="1.0" encoding='utf-8'`, "utf-8"}, + {`version="1.0" encoding='utf-8' `, "utf-8"}, + {`version="1.0" encoding=utf-8`, ""}, + {`encoding="FOO" `, "FOO"}, +} + +func TestProcInstEncoding(t *testing.T) { + for _, test := range procInstTests { + got := procInstEncoding(test.input) + if got != test.expect { + t.Errorf("procInstEncoding(%q) = %q; want %q", test.input, got, test.expect) + } + } +} diff --git a/src/run.bash b/src/run.bash index ea98403f7..bb3d06c45 100755 --- a/src/run.bash +++ b/src/run.bash @@ -33,8 +33,7 @@ xcd() { if $rebuild; then (xcd pkg gomake clean - time gomake - gomake install + time gomake install ) || exit $i fi @@ -43,18 +42,10 @@ gomake testshort ) || exit $? (xcd pkg/sync; -if $rebuild; then - gomake clean; - time gomake -fi GOMAXPROCS=10 gomake testshort ) || exit $? (xcd cmd/ebnflint -if $rebuild; then - gomake clean; - time gomake -fi time gomake test ) || exit $? @@ -83,7 +74,6 @@ gomake clean time gomake ogle ) || exit $? -[ "$GOHOSTOS" == windows ] || (xcd ../doc/progs time ./run ) || exit $? |