| author | Ondřej Surý <ondrej@sury.org> | 2011-04-26 09:55:32 +0200 |
|---|---|---|
| committer | Ondřej Surý <ondrej@sury.org> | 2011-04-26 09:55:32 +0200 |
| commit | 7b15ed9ef455b6b66c6b376898a88aef5d6a9970 (patch) | |
| tree | 3ef530baa80cdf29436ba981f5783be6b4d2202b /src/cmd | |
| parent | 50104cc32a498f7517a51c8dc93106c51c7a54b4 (diff) | |
| download | golang-7b15ed9ef455b6b66c6b376898a88aef5d6a9970.tar.gz | |
Imported Upstream version 2011.04.13 (tag: upstream/2011.04.13)
Diffstat (limited to 'src/cmd')
149 files changed, 24963 insertions, 1464 deletions
diff --git a/src/cmd/5c/peep.c b/src/cmd/5c/peep.c index 8945ee732..c15bf0fc4 100644 --- a/src/cmd/5c/peep.c +++ b/src/cmd/5c/peep.c @@ -1100,7 +1100,7 @@ copyu(Prog *p, Adr *v, Adr *s) if(v->type == D_REG) { if(v->reg <= REGEXT && v->reg > exregoffset) return 2; - if(v->reg == REGARG) + if(v->reg == (uchar)REGARG) return 2; } if(v->type == D_FREG) @@ -1118,7 +1118,7 @@ copyu(Prog *p, Adr *v, Adr *s) case ATEXT: /* funny */ if(v->type == D_REG) - if(v->reg == REGARG) + if(v->reg == (uchar)REGARG) return 3; return 0; } diff --git a/src/cmd/5c/txt.c b/src/cmd/5c/txt.c index f5619f800..4be1f6f62 100644 --- a/src/cmd/5c/txt.c +++ b/src/cmd/5c/txt.c @@ -400,6 +400,10 @@ regsalloc(Node *n, Node *nn) void regaalloc1(Node *n, Node *nn) { + if(REGARG < 0) { + fatal(n, "regaalloc1 and REGARG<0"); + return; + } nodreg(n, nn, REGARG); reg[REGARG]++; curarg = align(curarg, nn->type, Aarg1, nil); diff --git a/src/cmd/5g/ggen.c b/src/cmd/5g/ggen.c index 182d7f147..7197709d4 100644 --- a/src/cmd/5g/ggen.c +++ b/src/cmd/5g/ggen.c @@ -32,7 +32,7 @@ compile(Node *fn) return; // set up domain for labels - labellist = L; + clearlabels(); lno = setlineno(fn); diff --git a/src/cmd/5g/gobj.c b/src/cmd/5g/gobj.c index bf59534b9..acece6c0d 100644 --- a/src/cmd/5g/gobj.c +++ b/src/cmd/5g/gobj.c @@ -268,7 +268,7 @@ static Prog *estrdat; static int gflag; static Prog *savepc; -static void +void data(void) { gflag = debug['g']; @@ -285,7 +285,7 @@ data(void) pc = estrdat; } -static void +void text(void) { if(!savepc) @@ -310,6 +310,29 @@ dumpdata(void) pc = estrdat; } +int +dsname(Sym *sym, int off, char *t, int n) +{ + Prog *p; + + p = gins(ADATA, N, N); + p->from.type = D_OREG; + p->from.name = D_EXTERN; + p->from.etype = TINT32; + p->from.offset = off; + p->from.reg = NREG; + p->from.sym = sym; + + p->reg = n; + + p->to.type = D_SCONST; + p->to.name = D_NONE; + p->to.reg = NREG; + p->to.offset = 0; + memmove(p->to.sval, t, n); + return off + n; +} + /* * make a refer to the data s, s+len * emitting DATA if needed. @@ -317,76 +340,15 @@ dumpdata(void) void datastring(char *s, int len, Addr *a) { - int w; - Prog *p; - Addr ac, ao; - static int gen; - struct { - Strlit lit; - char buf[100]; - } tmp; - - // string - memset(&ao, 0, sizeof(ao)); - ao.type = D_OREG; - ao.name = D_STATIC; - ao.etype = TINT32; - ao.offset = 0; // fill in - ao.reg = NREG; - - // constant - memset(&ac, 0, sizeof(ac)); - ac.type = D_CONST; - ac.name = D_NONE; - ac.offset = 0; // fill in - ac.reg = NREG; - - // huge strings are made static to avoid long names. - if(len > 100) { - snprint(namebuf, sizeof(namebuf), ".string.%d", gen++); - ao.sym = lookup(namebuf); - ao.name = D_STATIC; - } else { - if(len > 0 && s[len-1] == '\0') - len--; - tmp.lit.len = len; - memmove(tmp.lit.s, s, len); - tmp.lit.s[len] = '\0'; - len++; - snprint(namebuf, sizeof(namebuf), "\"%Z\"", &tmp.lit); - ao.sym = pkglookup(namebuf, stringpkg); - ao.name = D_EXTERN; - } - *a = ao; - - // only generate data the first time. 
- if(ao.sym->flags & SymUniq) - return; - ao.sym->flags |= SymUniq; - - data(); - for(w=0; w<len; w+=8) { - p = pc; - gins(ADATA, N, N); - - // DATA s+w, [NSNAME], $"xxx" - p->from = ao; - p->from.offset = w; - - p->reg = NSNAME; - if(w+8 > len) - p->reg = len-w; - - p->to = ac; - p->to.type = D_SCONST; - p->to.offset = len; - memmove(p->to.sval, s+w, p->reg); - } - p = pc; - ggloblsym(ao.sym, len, ao.name == D_EXTERN); - if(ao.name == D_STATIC) - p->from.name = D_STATIC; - text(); + Sym *sym; + + sym = stringsym(s, len); + a->type = D_OREG; + a->name = D_EXTERN; + a->etype = TINT32; + a->offset = widthptr+4; // skip header + a->reg = NREG; + a->sym = sym; } /* @@ -396,77 +358,15 @@ datastring(char *s, int len, Addr *a) void datagostring(Strlit *sval, Addr *a) { - Prog *p; - Addr ac, ao, ap; - int32 wi, wp; - static int gen; - - memset(&ac, 0, sizeof(ac)); - memset(&ao, 0, sizeof(ao)); - memset(&ap, 0, sizeof(ap)); - - // constant - ac.type = D_CONST; - ac.name = D_NONE; - ac.offset = 0; // fill in - ac.reg = NREG; - - // string len+ptr - ao.type = D_OREG; - ao.name = D_STATIC; // fill in - ao.etype = TINT32; - ao.sym = nil; // fill in - ao.reg = NREG; - - // $string len+ptr - datastring(sval->s, sval->len, &ap); - ap.type = D_CONST; - ap.etype = TINT32; - - wi = types[TUINT32]->width; - wp = types[tptr]->width; - - if(ap.name == D_STATIC) { - // huge strings are made static to avoid long names - snprint(namebuf, sizeof(namebuf), ".gostring.%d", ++gen); - ao.sym = lookup(namebuf); - ao.name = D_STATIC; - } else { - // small strings get named by their contents, - // so that multiple modules using the same string - // can share it. - snprint(namebuf, sizeof(namebuf), "\"%Z\"", sval); - ao.sym = pkglookup(namebuf, gostringpkg); - ao.name = D_EXTERN; - } - - *a = ao; - if(ao.sym->flags & SymUniq) - return; - ao.sym->flags |= SymUniq; - - data(); - // DATA gostring, wp, $cstring - p = pc; - gins(ADATA, N, N); - p->from = ao; - p->reg = wp; - p->to = ap; - - // DATA gostring+wp, wi, $len - p = pc; - gins(ADATA, N, N); - p->from = ao; - p->from.offset = wp; - p->reg = wi; - p->to = ac; - p->to.offset = sval->len; - - p = pc; - ggloblsym(ao.sym, types[TSTRING]->width, ao.name == D_EXTERN); - if(ao.name == D_STATIC) - p->from.name = D_STATIC; - text(); + Sym *sym; + + sym = stringsym(sval->s, sval->len); + a->type = D_OREG; + a->name = D_EXTERN; + a->etype = TINT32; + a->offset = 0; // header + a->reg = NREG; + a->sym = sym; } void diff --git a/src/cmd/5l/asm.c b/src/cmd/5l/asm.c index af6d1dfda..e2583e7c3 100644 --- a/src/cmd/5l/asm.c +++ b/src/cmd/5l/asm.c @@ -1978,11 +1978,16 @@ genasmsym(void (*put)(Sym*, char*, int, vlong, vlong, int, Sym*)) for(h=0; h<NHASH; h++) { for(s=hash[h]; s!=S; s=s->hash) { + if(s->hide) + continue; switch(s->type) { case SCONST: case SRODATA: case SDATA: case SELFDATA: + case STYPE: + case SSTRING: + case SGOSTRING: if(!s->reachable) continue; put(s, s->name, 'D', s->value, s->size, s->version, s->gotype); diff --git a/src/cmd/5l/l.h b/src/cmd/5l/l.h index 2e887dad7..cf5a9990b 100644 --- a/src/cmd/5l/l.h +++ b/src/cmd/5l/l.h @@ -136,6 +136,7 @@ struct Sym uchar dynexport; uchar leaf; uchar stkcheck; + uchar hide; int32 dynid; int32 plt; int32 got; @@ -147,6 +148,7 @@ struct Sym uchar foreign; // called by arm if thumb, by thumb if arm uchar fnptr; // used as fn ptr Sym* hash; // in hash table + Sym* allsym; // in all symbol list Sym* next; // in text or data list Sym* sub; // in SSUB list Sym* outer; // container of sub @@ -202,22 +204,6 @@ struct Count enum { - 
Sxxx, - - /* order here is order in output file */ - STEXT = 1, - SRODATA, - SELFDATA, - SDATA, - SBSS, - - SXREF, - SFILE, - SCONST, - SDYNIMPORT, - - SSUB = 1<<8, - LFROM = 1<<0, LTO = 1<<1, LPOOL = 1<<2, @@ -280,7 +266,6 @@ enum LEAF = 1<<2, STRINGSZ = 200, - NHASH = 10007, MINSIZ = 64, NENT = 100, MAXIO = 8192, diff --git a/src/cmd/5l/noop.c b/src/cmd/5l/noop.c index a5e66f038..bdcc3cad8 100644 --- a/src/cmd/5l/noop.c +++ b/src/cmd/5l/noop.c @@ -364,14 +364,14 @@ noops(void) p = appendp(p); p->as = ABL; p->scond = C_SCOND_LO; - p->to.type = D_BRANCH; + p->to.type = D_BRANCH; p->to.sym = symmorestack; p->cond = pmorestack; // MOVW.W R14,$-autosize(SP) p = appendp(p); p->as = AMOVW; - p->scond |= C_WBIT; + p->scond |= C_WBIT; p->from.type = D_REG; p->from.reg = REGLINK; p->to.type = D_OREG; @@ -413,14 +413,14 @@ noops(void) // BL runtime.morestack(SB) // modifies LR p = appendp(p); p->as = ABL; - p->to.type = D_BRANCH; + p->to.type = D_BRANCH; p->to.sym = symmorestack; p->cond = pmorestack; // MOVW.W R14,$-autosize(SP) p = appendp(p); p->as = AMOVW; - p->scond |= C_WBIT; + p->scond |= C_WBIT; p->from.type = D_REG; p->from.reg = REGLINK; p->to.type = D_OREG; @@ -450,6 +450,8 @@ noops(void) } } if(thumb){ + diag("thumb not maintained"); + errorexit(); if(cursym->text->mark & LEAF){ if(autosize){ p->as = AADD; @@ -481,7 +483,7 @@ noops(void) q->to.type = D_REG; q->to.reg = REGSP; q->link = p->link; - p->link = q; + p->link = q; } else q = p; @@ -492,6 +494,8 @@ noops(void) break; } if(foreign) { + diag("foreign not maintained"); + errorexit(); // if(foreign) print("ABXRET 3 %s\n", cursym->name); #define R 1 p->as = AMOVW; @@ -530,7 +534,9 @@ noops(void) p->from.reg = REGSP; p->to.type = D_REG; p->to.reg = REGPC; - // no spadj because it doesn't fall through + // If there are instructions following + // this ARET, they come from a branch + // with the same stackframe, so no spadj. } break; diff --git a/src/cmd/5l/softfloat.c b/src/cmd/5l/softfloat.c index fd66b0969..03d8c6d26 100644 --- a/src/cmd/5l/softfloat.c +++ b/src/cmd/5l/softfloat.c @@ -4,6 +4,7 @@ #define EXTERN #include "l.h" +#include "../ld/lib.h" // Software floating point. diff --git a/src/cmd/5l/thumb.c b/src/cmd/5l/thumb.c index b2ba630c3..a6f729bed 100644 --- a/src/cmd/5l/thumb.c +++ b/src/cmd/5l/thumb.c @@ -29,6 +29,7 @@ // THE SOFTWARE. 
#include "l.h" +#include "../ld/lib.h" static int32 thumboprr(int); static int32 thumboprrr(int, int); diff --git a/src/cmd/6c/peep.c b/src/cmd/6c/peep.c index 13fd25e73..8b82adbf5 100644 --- a/src/cmd/6c/peep.c +++ b/src/cmd/6c/peep.c @@ -797,7 +797,7 @@ copyu(Prog *p, Adr *v, Adr *s) return 3; case ACALL: /* funny */ - if(REGARG >= 0 && v->type == REGARG) + if(REGARG >= 0 && v->type == (uchar)REGARG) return 2; if(s != A) { @@ -810,7 +810,7 @@ copyu(Prog *p, Adr *v, Adr *s) return 3; case ATEXT: /* funny */ - if(REGARG >= 0 && v->type == REGARG) + if(REGARG >= 0 && v->type == (uchar)REGARG) return 3; return 0; } diff --git a/src/cmd/6c/txt.c b/src/cmd/6c/txt.c index a78ba227b..12fc5b498 100644 --- a/src/cmd/6c/txt.c +++ b/src/cmd/6c/txt.c @@ -436,8 +436,10 @@ regsalloc(Node *n, Node *nn) void regaalloc1(Node *n, Node *nn) { - if(REGARG < 0) - diag(n, "regaalloc1 and REGARG<0"); + if(REGARG < 0) { + fatal(n, "regaalloc1 and REGARG<0"); + return; + } nodreg(n, nn, REGARG); reg[REGARG]++; curarg = align(curarg, nn->type, Aarg1, nil); diff --git a/src/cmd/6g/ggen.c b/src/cmd/6g/ggen.c index d9fa1793c..8d89fb164 100644 --- a/src/cmd/6g/ggen.c +++ b/src/cmd/6g/ggen.c @@ -32,7 +32,7 @@ compile(Node *fn) return; // set up domain for labels - labellist = L; + clearlabels(); lno = setlineno(fn); diff --git a/src/cmd/6g/gobj.c b/src/cmd/6g/gobj.c index b667ae48a..507764a3b 100644 --- a/src/cmd/6g/gobj.c +++ b/src/cmd/6g/gobj.c @@ -280,7 +280,7 @@ static Prog *estrdat; static int gflag; static Prog *savepc; -static void +void data(void) { gflag = debug['g']; @@ -297,7 +297,7 @@ data(void) pc = estrdat; } -static void +void text(void) { if(!savepc) @@ -322,6 +322,24 @@ dumpdata(void) pc = estrdat; } +int +dsname(Sym *s, int off, char *t, int n) +{ + Prog *p; + + p = gins(ADATA, N, N); + p->from.type = D_EXTERN; + p->from.index = D_NONE; + p->from.offset = off; + p->from.scale = n; + p->from.sym = s; + + p->to.type = D_SCONST; + p->to.index = D_NONE; + memmove(p->to.sval, t, n); + return off + n; +} + /* * make a refer to the data s, s+len * emitting DATA if needed. @@ -329,74 +347,13 @@ dumpdata(void) void datastring(char *s, int len, Addr *a) { - int w; - Prog *p; - Addr ac, ao; - static int gen; - struct { - Strlit lit; - char buf[100]; - } tmp; - - // string - memset(&ao, 0, sizeof(ao)); - ao.type = D_STATIC; - ao.index = D_NONE; - ao.etype = TINT32; - ao.offset = 0; // fill in - - // constant - memset(&ac, 0, sizeof(ac)); - ac.type = D_CONST; - ac.index = D_NONE; - ac.offset = 0; // fill in - - // huge strings are made static to avoid long names. - if(len > 100) { - snprint(namebuf, sizeof(namebuf), ".string.%d", gen++); - ao.sym = lookup(namebuf); - ao.type = D_STATIC; - } else { - if(len > 0 && s[len-1] == '\0') - len--; - tmp.lit.len = len; - memmove(tmp.lit.s, s, len); - tmp.lit.s[len] = '\0'; - len++; - snprint(namebuf, sizeof(namebuf), "\"%Z\"", &tmp.lit); - ao.sym = pkglookup(namebuf, stringpkg); - ao.type = D_EXTERN; - } - *a = ao; - - // only generate data the first time. 
- if(ao.sym->flags & SymUniq) - return; - ao.sym->flags |= SymUniq; - - data(); - for(w=0; w<len; w+=8) { - p = pc; - gins(ADATA, N, N); - - // DATA s+w, [NSNAME], $"xxx" - p->from = ao; - p->from.offset = w; - - p->from.scale = NSNAME; - if(w+8 > len) - p->from.scale = len-w; - - p->to = ac; - p->to.type = D_SCONST; - p->to.offset = len; - memmove(p->to.sval, s+w, p->from.scale); - } - p = pc; - ggloblsym(ao.sym, len, ao.type == D_EXTERN); - if(ao.type == D_STATIC) - p->from.type = D_STATIC; - text(); + Sym *sym; + + sym = stringsym(s, len); + a->type = D_EXTERN; + a->sym = sym; + a->offset = widthptr+4; // skip header + a->etype = TINT32; } /* @@ -406,76 +363,13 @@ datastring(char *s, int len, Addr *a) void datagostring(Strlit *sval, Addr *a) { - Prog *p; - Addr ac, ao, ap; - int32 wi, wp; - static int gen; - - memset(&ac, 0, sizeof(ac)); - memset(&ao, 0, sizeof(ao)); - memset(&ap, 0, sizeof(ap)); - - // constant - ac.type = D_CONST; - ac.index = D_NONE; - ac.offset = 0; // fill in - - // string len+ptr - ao.type = D_STATIC; // fill in - ao.index = D_NONE; - ao.etype = TINT32; - ao.sym = nil; // fill in - - // $string len+ptr - datastring(sval->s, sval->len, &ap); - ap.index = ap.type; - ap.type = D_ADDR; - ap.etype = TINT32; - - wi = types[TUINT32]->width; - wp = types[tptr]->width; - - if(ap.index == D_STATIC) { - // huge strings are made static to avoid long names - snprint(namebuf, sizeof(namebuf), ".gostring.%d", ++gen); - ao.sym = lookup(namebuf); - ao.type = D_STATIC; - } else { - // small strings get named by their contents, - // so that multiple modules using the same string - // can share it. - snprint(namebuf, sizeof(namebuf), "\"%Z\"", sval); - ao.sym = pkglookup(namebuf, gostringpkg); - ao.type = D_EXTERN; - } - - *a = ao; - if(ao.sym->flags & SymUniq) - return; - ao.sym->flags |= SymUniq; - - data(); - // DATA gostring, wp, $cstring - p = pc; - gins(ADATA, N, N); - p->from = ao; - p->from.scale = wp; - p->to = ap; - - // DATA gostring+wp, wi, $len - p = pc; - gins(ADATA, N, N); - p->from = ao; - p->from.offset = wp; - p->from.scale = wi; - p->to = ac; - p->to.offset = sval->len; - - p = pc; - ggloblsym(ao.sym, types[TSTRING]->width, ao.type == D_EXTERN); - if(ao.type == D_STATIC) - p->from.type = D_STATIC; - text(); + Sym *sym; + + sym = stringsym(sval->s, sval->len); + a->type = D_EXTERN; + a->sym = sym; + a->offset = 0; // header + a->etype = TINT32; } void diff --git a/src/cmd/6l/asm.c b/src/cmd/6l/asm.c index fb041d83a..ba2074fde 100644 --- a/src/cmd/6l/asm.c +++ b/src/cmd/6l/asm.c @@ -1101,32 +1101,34 @@ genasmsym(void (*put)(Sym*, char*, int, vlong, vlong, int, Sym*)) { Auto *a; Sym *s; - int h; - for(h=0; h<NHASH; h++) { - for(s=hash[h]; s!=S; s=s->hash) { - switch(s->type&~SSUB) { - case SCONST: - case SRODATA: - case SDATA: - case SELFDATA: - case SMACHOGOT: - case SWINDOWS: - if(!s->reachable) - continue; - put(s, s->name, 'D', symaddr(s), s->size, s->version, s->gotype); + for(s=allsym; s!=S; s=s->allsym) { + if(s->hide) + continue; + switch(s->type&~SSUB) { + case SCONST: + case SRODATA: + case SDATA: + case SELFDATA: + case SMACHOGOT: + case STYPE: + case SSTRING: + case SGOSTRING: + case SWINDOWS: + if(!s->reachable) continue; + put(s, s->name, 'D', symaddr(s), s->size, s->version, s->gotype); + continue; - case SBSS: - if(!s->reachable) - continue; - put(s, s->name, 'B', symaddr(s), s->size, s->version, s->gotype); + case SBSS: + if(!s->reachable) continue; + put(s, s->name, 'B', symaddr(s), s->size, s->version, s->gotype); + continue; - case SFILE: - 
put(nil, s->name, 'f', s->value, 0, s->version, 0); - continue; - } + case SFILE: + put(nil, s->name, 'f', s->value, 0, s->version, 0); + continue; } } diff --git a/src/cmd/6l/l.h b/src/cmd/6l/l.h index 6933d8eb1..4fc13b94a 100644 --- a/src/cmd/6l/l.h +++ b/src/cmd/6l/l.h @@ -132,11 +132,13 @@ struct Sym uchar dynexport; uchar special; uchar stkcheck; + uchar hide; int32 dynid; int32 sig; int32 plt; int32 got; Sym* hash; // in hash table + Sym* allsym; // in all symbol list Sym* next; // in text or data list Sym* sub; // in SSUB list Sym* outer; // container of sub @@ -177,29 +179,6 @@ struct Movtab enum { - Sxxx, - - /* order here is order in output file */ - STEXT = 1, - SELFDATA, - SMACHOPLT, - SRODATA, - SDATA, - SMACHOGOT, - SWINDOWS, - SBSS, - - SXREF, - SMACHODYNSTR, - SMACHODYNSYM, - SMACHOINDIRECTPLT, - SMACHOINDIRECTGOT, - SFILE, - SCONST, - SDYNIMPORT, - SSUB = 1<<8, - - NHASH = 10007, MINSIZ = 8, STRINGSZ = 200, MINLC = 1, diff --git a/src/cmd/6l/obj.c b/src/cmd/6l/obj.c index f113e3ec1..6b43d2df4 100644 --- a/src/cmd/6l/obj.c +++ b/src/cmd/6l/obj.c @@ -287,7 +287,7 @@ zsym(char *pn, Biobuf *f, Sym *h[]) { int o; - o = Bgetc(f); + o = BGETC(f); if(o < 0 || o >= NSYM || h[o] == nil) mangle(pn); return h[o]; @@ -301,12 +301,12 @@ zaddr(char *pn, Biobuf *f, Adr *a, Sym *h[]) Sym *s; Auto *u; - t = Bgetc(f); + t = BGETC(f); a->index = D_NONE; a->scale = 0; if(t & T_INDEX) { - a->index = Bgetc(f); - a->scale = Bgetc(f); + a->index = BGETC(f); + a->scale = BGETC(f); } a->offset = 0; if(t & T_OFFSET) { @@ -330,7 +330,7 @@ zaddr(char *pn, Biobuf *f, Adr *a, Sym *h[]) a->type = D_SCONST; } if(t & T_TYPE) - a->type = Bgetc(f); + a->type = BGETC(f); if(a->type < 0 || a->type >= D_SIZE) mangle(pn); adrgotype = S; @@ -405,10 +405,10 @@ newloop: loop: if(f->state == Bracteof || Boffset(f) >= eof) goto eof; - o = Bgetc(f); + o = BGETC(f); if(o == Beof) goto eof; - o |= Bgetc(f) << 8; + o |= BGETC(f) << 8; if(o <= AXXX || o >= ALAST) { if(o < 0) goto eof; @@ -421,8 +421,8 @@ loop: sig = 0; if(o == ASIGNAME) sig = Bget4(f); - v = Bgetc(f); /* type */ - o = Bgetc(f); /* sym */ + v = BGETC(f); /* type */ + o = BGETC(f); /* sym */ r = 0; if(v == D_STATIC) r = version; diff --git a/src/cmd/6l/pass.c b/src/cmd/6l/pass.c index 8fda94392..0b0ee1253 100644 --- a/src/cmd/6l/pass.c +++ b/src/cmd/6l/pass.c @@ -274,10 +274,10 @@ patch(void) if(HEADTYPE == Hwindows) { // Windows // Convert - // op n(GS), reg + // op n(GS), reg // to // MOVL 0x58(GS), reg - // op n(reg), reg + // op n(reg), reg // The purpose of this patch is to fix some accesses // to extern register variables (TLS) on Windows, as // a different method is used to access them. @@ -674,6 +674,11 @@ dostkoff(void) p->spadj = -autoffset; p = appendp(p); p->as = ARET; + // If there are instructions following + // this ARET, they come from a branch + // with the same stackframe, so undo + // the cleanup. 
+ p->spadj = +autoffset; } } } diff --git a/src/cmd/8c/peep.c b/src/cmd/8c/peep.c index 9e18fc94d..9511a5579 100644 --- a/src/cmd/8c/peep.c +++ b/src/cmd/8c/peep.c @@ -713,7 +713,7 @@ copyu(Prog *p, Adr *v, Adr *s) return 3; case ACALL: /* funny */ - if(REGARG >= 0 && v->type == REGARG) + if(REGARG >= 0 && v->type == (uchar)REGARG) return 2; if(s != A) { diff --git a/src/cmd/8c/txt.c b/src/cmd/8c/txt.c index 0dd387d11..b2e0148a0 100644 --- a/src/cmd/8c/txt.c +++ b/src/cmd/8c/txt.c @@ -397,6 +397,10 @@ regsalloc(Node *n, Node *nn) void regaalloc1(Node *n, Node *nn) { + if(REGARG < 0) { + fatal(n, "regaalloc1 and REGARG<0"); + return; + } nodreg(n, nn, REGARG); reg[REGARG]++; curarg = align(curarg, nn->type, Aarg1, nil); diff --git a/src/cmd/8g/ggen.c b/src/cmd/8g/ggen.c index 4dcbd4489..8db552493 100644 --- a/src/cmd/8g/ggen.c +++ b/src/cmd/8g/ggen.c @@ -32,7 +32,7 @@ compile(Node *fn) return; // set up domain for labels - labellist = L; + clearlabels(); lno = setlineno(fn); diff --git a/src/cmd/8g/gobj.c b/src/cmd/8g/gobj.c index e48ad529b..bc1dfe8bf 100644 --- a/src/cmd/8g/gobj.c +++ b/src/cmd/8g/gobj.c @@ -320,6 +320,24 @@ dumpdata(void) pc = estrdat; } +int +dsname(Sym *s, int off, char *t, int n) +{ + Prog *p; + + p = gins(ADATA, N, N); + p->from.type = D_EXTERN; + p->from.index = D_NONE; + p->from.offset = off; + p->from.scale = n; + p->from.sym = s; + + p->to.type = D_SCONST; + p->to.index = D_NONE; + memmove(p->to.sval, t, n); + return off + n; +} + /* * make a refer to the data s, s+len * emitting DATA if needed. @@ -327,74 +345,13 @@ dumpdata(void) void datastring(char *s, int len, Addr *a) { - int w; - Prog *p; - Addr ac, ao; - static int gen; - struct { - Strlit lit; - char buf[100]; - } tmp; - - // string - memset(&ao, 0, sizeof(ao)); - ao.type = D_STATIC; - ao.index = D_NONE; - ao.etype = TINT32; - ao.offset = 0; // fill in - - // constant - memset(&ac, 0, sizeof(ac)); - ac.type = D_CONST; - ac.index = D_NONE; - ac.offset = 0; // fill in - - // huge strings are made static to avoid long names. - if(len > 100) { - snprint(namebuf, sizeof(namebuf), ".string.%d", gen++); - ao.sym = lookup(namebuf); - ao.type = D_STATIC; - } else { - if(len > 0 && s[len-1] == '\0') - len--; - tmp.lit.len = len; - memmove(tmp.lit.s, s, len); - tmp.lit.s[len] = '\0'; - len++; - snprint(namebuf, sizeof(namebuf), "\"%Z\"", &tmp.lit); - ao.sym = pkglookup(namebuf, stringpkg); - ao.type = D_EXTERN; - } - *a = ao; - - // only generate data the first time. 
- if(ao.sym->flags & SymUniq) - return; - ao.sym->flags |= SymUniq; - - data(); - for(w=0; w<len; w+=8) { - p = pc; - gins(ADATA, N, N); - - // DATA s+w, [NSNAME], $"xxx" - p->from = ao; - p->from.offset = w; - - p->from.scale = NSNAME; - if(w+8 > len) - p->from.scale = len-w; - - p->to = ac; - p->to.type = D_SCONST; - p->to.offset = len; - memmove(p->to.sval, s+w, p->from.scale); - } - p = pc; - ggloblsym(ao.sym, len, ao.type == D_EXTERN); - if(ao.type == D_STATIC) - p->from.type = D_STATIC; - text(); + Sym *sym; + + sym = stringsym(s, len); + a->type = D_EXTERN; + a->sym = sym; + a->offset = widthptr+4; // skip header + a->etype = TINT32; } /* @@ -404,76 +361,13 @@ datastring(char *s, int len, Addr *a) void datagostring(Strlit *sval, Addr *a) { - Prog *p; - Addr ac, ao, ap; - int32 wi, wp; - static int gen; - - memset(&ac, 0, sizeof(ac)); - memset(&ao, 0, sizeof(ao)); - memset(&ap, 0, sizeof(ap)); - - // constant - ac.type = D_CONST; - ac.index = D_NONE; - ac.offset = 0; // fill in - - // string len+ptr - ao.type = D_STATIC; // fill in - ao.index = D_NONE; - ao.etype = TINT32; - ao.sym = nil; // fill in - - // $string len+ptr - datastring(sval->s, sval->len, &ap); - ap.index = ap.type; - ap.type = D_ADDR; - ap.etype = TINT32; - - wi = types[TUINT32]->width; - wp = types[tptr]->width; - - if(ap.index == D_STATIC) { - // huge strings are made static to avoid long names - snprint(namebuf, sizeof(namebuf), ".gostring.%d", ++gen); - ao.sym = lookup(namebuf); - ao.type = D_STATIC; - } else { - // small strings get named by their contents, - // so that multiple modules using the same string - // can share it. - snprint(namebuf, sizeof(namebuf), "\"%Z\"", sval); - ao.sym = pkglookup(namebuf, gostringpkg); - ao.type = D_EXTERN; - } - - *a = ao; - if(ao.sym->flags & SymUniq) - return; - ao.sym->flags |= SymUniq; - - data(); - // DATA gostring, wp, $cstring - p = pc; - gins(ADATA, N, N); - p->from = ao; - p->from.scale = wp; - p->to = ap; - - // DATA gostring+wp, wi, $len - p = pc; - gins(ADATA, N, N); - p->from = ao; - p->from.offset = wp; - p->from.scale = wi; - p->to = ac; - p->to.offset = sval->len; - - p = pc; - ggloblsym(ao.sym, types[TSTRING]->width, ao.type == D_EXTERN); - if(ao.type == D_STATIC) - p->from.type = D_STATIC; - text(); + Sym *sym; + + sym = stringsym(sval->s, sval->len); + a->type = D_EXTERN; + a->sym = sym; + a->offset = 0; // header + a->etype = TINT32; } void diff --git a/src/cmd/8g/peep.c b/src/cmd/8g/peep.c index 580b1a922..5ad29e1b2 100644 --- a/src/cmd/8g/peep.c +++ b/src/cmd/8g/peep.c @@ -120,6 +120,25 @@ peep(void) p = p->link; } } + + // movb elimination. + // movb is simulated by the linker + // when a register other than ax, bx, cx, dx + // is used, so rewrite to other instructions + // when possible. a movb into a register + // can smash the entire 32-bit register without + // causing any trouble. + for(r=firstr; r!=R; r=r->link) { + p = r->prog; + if(p->as == AMOVB && regtyp(&p->to)) { + // movb into register. + // from another register or constant can be movl. 
+ if(regtyp(&p->from) || p->from.type == D_CONST) + p->as = AMOVL; + else + p->as = AMOVBLZX; + } + } // constant propagation // find MOV $con,R followed by @@ -152,6 +171,8 @@ loop1: for(r=firstr; r!=R; r=r->link) { p = r->prog; switch(p->as) { + case AMOVB: + case AMOVW: case AMOVL: if(regtyp(&p->to)) if(regtyp(&p->from)) { @@ -182,6 +203,7 @@ loop1: } break; + case AADDB: case AADDL: case AADDW: if(p->from.type != D_CONST || needc(p->link)) @@ -204,6 +226,7 @@ loop1: } break; + case ASUBB: case ASUBL: case ASUBW: if(p->from.type != D_CONST || needc(p->link)) @@ -380,6 +403,8 @@ subprop(Reg *r0) case AMOVSL: return 0; + case AMOVB: + case AMOVW: case AMOVL: if(p->to.type == v1->type) goto gotit; @@ -560,6 +585,8 @@ copyu(Prog *p, Adr *v, Adr *s) case ANOP: /* rhs store */ + case AMOVB: + case AMOVW: case AMOVL: case AMOVBLSX: case AMOVBLZX: @@ -624,8 +651,6 @@ copyu(Prog *p, Adr *v, Adr *s) case AXORB: case AXORL: case AXORW: - case AMOVB: - case AMOVW: if(copyas(&p->to, v)) return 2; goto caseread; diff --git a/src/cmd/8l/asm.c b/src/cmd/8l/asm.c index 1e760d89e..b9bd0dae9 100644 --- a/src/cmd/8l/asm.c +++ b/src/cmd/8l/asm.c @@ -307,6 +307,7 @@ elfsetupplt(void) int archreloc(Reloc *r, Sym *s, vlong *val) { + USED(s); switch(r->type) { case D_CONST: *val = r->add; @@ -644,7 +645,7 @@ asmb(void) { int32 v, magic; int a, dynsym; - uint32 va, fo, w, symo, startva, machlink; + uint32 symo, startva, machlink; ElfEhdr *eh; ElfPhdr *ph, *pph; ElfShdr *sh; @@ -776,7 +777,6 @@ asmb(void) lputb(0L); lputb(~0L); /* gp value ?? */ break; - lputl(0); /* x */ case Hunixcoff: /* unix coff */ /* * file header @@ -892,13 +892,10 @@ asmb(void) debug['d'] = 1; eh = getElfEhdr(); - fo = HEADR; startva = INITTEXT - HEADR; - va = startva + fo; - w = segtext.filelen; /* This null SHdr must appear before all others */ - sh = newElfShdr(elfstr[ElfStrEmpty]); + newElfShdr(elfstr[ElfStrEmpty]); /* program header info */ pph = newElfPhdr(); @@ -1158,6 +1155,8 @@ genasmsym(void (*put)(Sym*, char*, int, vlong, vlong, int, Sym*)) for(h=0; h<NHASH; h++) { for(s=hash[h]; s!=S; s=s->hash) { + if(s->hide) + continue; switch(s->type&~SSUB) { case SCONST: case SRODATA: @@ -1165,6 +1164,9 @@ genasmsym(void (*put)(Sym*, char*, int, vlong, vlong, int, Sym*)) case SELFDATA: case SMACHO: case SMACHOGOT: + case STYPE: + case SSTRING: + case SGOSTRING: case SWINDOWS: if(!s->reachable) continue; @@ -1209,6 +1211,6 @@ genasmsym(void (*put)(Sym*, char*, int, vlong, vlong, int, Sym*)) put(nil, a->asym->name, 'p', a->aoffset, 0, 0, a->gotype); } if(debug['v'] || debug['n']) - Bprint(&bso, "symsize = %ud\n", symsize); + Bprint(&bso, "symsize = %uld\n", symsize); Bflush(&bso); } diff --git a/src/cmd/8l/l.h b/src/cmd/8l/l.h index e4650ee58..ac0f3953f 100644 --- a/src/cmd/8l/l.h +++ b/src/cmd/8l/l.h @@ -131,6 +131,7 @@ struct Sym uchar dynexport; uchar special; uchar stkcheck; + uchar hide; int32 value; int32 size; int32 sig; @@ -138,6 +139,7 @@ struct Sym int32 plt; int32 got; Sym* hash; // in hash table + Sym* allsym; // in all symbol list Sym* next; // in text or data list Sym* sub; // in sub list Sym* outer; // container of sub @@ -168,31 +170,6 @@ struct Optab enum { - Sxxx, - - /* order here is order in output file */ - STEXT, - SELFDATA, - SMACHOPLT, - SRODATA, - SDATA, - SMACHO, /* Mach-O __nl_symbol_ptr */ - SMACHOGOT, - SWINDOWS, - SBSS, - - SXREF, - SMACHODYNSTR, - SMACHODYNSYM, - SMACHOINDIRECTPLT, - SMACHOINDIRECTGOT, - SFILE, - SCONST, - SDYNIMPORT, - - SSUB = 1<<8, /* sub-symbol, linked from parent via ->sub list */ - - 
NHASH = 10007, MINSIZ = 4, STRINGSZ = 200, MINLC = 1, diff --git a/src/cmd/8l/pass.c b/src/cmd/8l/pass.c index 294926f29..28589b66a 100644 --- a/src/cmd/8l/pass.c +++ b/src/cmd/8l/pass.c @@ -614,14 +614,17 @@ dostkoff(void) diag("unbalanced PUSH/POP"); if(autoffset) { - q = p; + p->as = AADJSP; + p->from.type = D_CONST; + p->from.offset = -autoffset; + p->spadj = -autoffset; p = appendp(p); p->as = ARET; - - q->as = AADJSP; - q->from.type = D_CONST; - q->from.offset = -autoffset; - p->spadj = -autoffset; + // If there are instructions following + // this ARET, they come from a branch + // with the same stackframe, so undo + // the cleanup. + p->spadj = +autoffset; } } } diff --git a/src/cmd/Makefile b/src/cmd/Makefile index 104e9f5df..5a37733de 100644 --- a/src/cmd/Makefile +++ b/src/cmd/Makefile @@ -18,7 +18,7 @@ DIRS=\ gc\ godefs\ gopack\ - gotest\ + gotry\ nm\ prof\ @@ -41,8 +41,11 @@ CLEANDIRS=\ cgo\ ebnflint\ godoc\ + gofix\ gofmt\ goinstall\ + gotest\ + gotype\ goyacc\ hgpatch\ diff --git a/src/cmd/cgo/ast.go b/src/cmd/cgo/ast.go index 2eae22aed..46e33686d 100644 --- a/src/cmd/cgo/ast.go +++ b/src/cmd/cgo/ast.go @@ -30,7 +30,7 @@ func parse(name string, flags uint) *ast.File { } os.Exit(2) } - fatal("parsing %s: %s", name, err) + fatalf("parsing %s: %s", name, err) } return ast1 } @@ -180,7 +180,7 @@ func (f *File) saveExport(x interface{}, context string) { return } for _, c := range n.Doc.List { - if string(c.Text[0:9]) != "//export " { + if !strings.HasPrefix(string(c.Text), "//export ") { continue } @@ -325,26 +325,28 @@ func (f *File) walk(x interface{}, context string, visit func(*File, interface{} f.walk(n.Results, "expr", visit) case *ast.BranchStmt: case *ast.BlockStmt: - f.walk(n.List, "stmt", visit) + f.walk(n.List, context, visit) case *ast.IfStmt: f.walk(n.Init, "stmt", visit) f.walk(&n.Cond, "expr", visit) f.walk(n.Body, "stmt", visit) f.walk(n.Else, "stmt", visit) case *ast.CaseClause: - f.walk(n.Values, "expr", visit) + if context == "typeswitch" { + context = "type" + } else { + context = "expr" + } + f.walk(n.List, context, visit) f.walk(n.Body, "stmt", visit) case *ast.SwitchStmt: f.walk(n.Init, "stmt", visit) f.walk(&n.Tag, "expr", visit) - f.walk(n.Body, "stmt", visit) - case *ast.TypeCaseClause: - f.walk(n.Types, "type", visit) - f.walk(n.Body, "stmt", visit) + f.walk(n.Body, "switch", visit) case *ast.TypeSwitchStmt: f.walk(n.Init, "stmt", visit) f.walk(n.Assign, "stmt", visit) - f.walk(n.Body, "stmt", visit) + f.walk(n.Body, "typeswitch", visit) case *ast.CommClause: f.walk(n.Comm, "stmt", visit) f.walk(n.Body, "stmt", visit) diff --git a/src/cmd/cgo/gcc.go b/src/cmd/cgo/gcc.go index f7ecc9e14..ae5ca2c7d 100644 --- a/src/cmd/cgo/gcc.go +++ b/src/cmd/cgo/gcc.go @@ -79,7 +79,7 @@ NextLine: l = strings.TrimSpace(l[4:]) fields := strings.Split(l, ":", 2) if len(fields) != 2 { - fatal("%s: bad #cgo line: %s", srcfile, line) + fatalf("%s: bad #cgo line: %s", srcfile, line) } var k string @@ -97,17 +97,17 @@ NextLine: continue NextLine } default: - fatal("%s: bad #cgo option: %s", srcfile, fields[0]) + fatalf("%s: bad #cgo option: %s", srcfile, fields[0]) } if k != "CFLAGS" && k != "LDFLAGS" { - fatal("%s: unsupported #cgo option %s", srcfile, k) + fatalf("%s: unsupported #cgo option %s", srcfile, k) } v := strings.TrimSpace(fields[1]) args, err := splitQuoted(v) if err != nil { - fatal("%s: bad #cgo option %s: %s", srcfile, k, err.String()) + fatalf("%s: bad #cgo option %s: %s", srcfile, k, err.String()) } if oldv, ok := p.CgoFlags[k]; ok { p.CgoFlags[k] = oldv + " 
" + v @@ -317,7 +317,7 @@ func (p *Package) guessKinds(f *File) []*Name { b.WriteString("}\n") stderr := p.gccErrors(b.Bytes()) if stderr == "" { - fatal("gcc produced no output\non input:\n%s", b.Bytes()) + fatalf("gcc produced no output\non input:\n%s", b.Bytes()) } names := make([]*Name, len(toSniff)) @@ -383,7 +383,7 @@ func (p *Package) guessKinds(f *File) []*Name { error(token.NoPos, "could not determine kind of name for C.%s", n.Go) } if nerrors > 0 { - fatal("unresolved names") + fatalf("unresolved names") } return needType } @@ -422,7 +422,7 @@ func (p *Package) loadDWARF(f *File, names []*Name) { for { e, err := r.Next() if err != nil { - fatal("reading DWARF entry: %s", err) + fatalf("reading DWARF entry: %s", err) } if e == nil { break @@ -433,7 +433,7 @@ func (p *Package) loadDWARF(f *File, names []*Name) { for { e, err := r.Next() if err != nil { - fatal("reading DWARF entry: %s", err) + fatalf("reading DWARF entry: %s", err) } if e.Tag == 0 { break @@ -452,27 +452,27 @@ func (p *Package) loadDWARF(f *File, names []*Name) { name, _ := e.Val(dwarf.AttrName).(string) typOff, _ := e.Val(dwarf.AttrType).(dwarf.Offset) if name == "" || typOff == 0 { - fatal("malformed DWARF TagVariable entry") + fatalf("malformed DWARF TagVariable entry") } if !strings.HasPrefix(name, "__cgo__") { break } typ, err := d.Type(typOff) if err != nil { - fatal("loading DWARF type: %s", err) + fatalf("loading DWARF type: %s", err) } t, ok := typ.(*dwarf.PtrType) if !ok || t == nil { - fatal("internal error: %s has non-pointer type", name) + fatalf("internal error: %s has non-pointer type", name) } i, err := strconv.Atoi(name[7:]) if err != nil { - fatal("malformed __cgo__ name: %s", name) + fatalf("malformed __cgo__ name: %s", name) } if enums[i] != 0 { t, err := d.Type(enums[i]) if err != nil { - fatal("loading DWARF type: %s", err) + fatalf("loading DWARF type: %s", err) } types[i] = t } else { @@ -632,14 +632,14 @@ func (p *Package) gccDebug(stdin []byte) *dwarf.Data { if f, err = elf.Open(gccTmp); err != nil { if f, err = macho.Open(gccTmp); err != nil { if f, err = pe.Open(gccTmp); err != nil { - fatal("cannot parse gcc output %s as ELF or Mach-O or PE object", gccTmp) + fatalf("cannot parse gcc output %s as ELF or Mach-O or PE object", gccTmp) } } } d, err := f.DWARF() if err != nil { - fatal("cannot load DWARF debug information from %s: %s", gccTmp, err) + fatalf("cannot load DWARF debug information from %s: %s", gccTmp, err) } return d } @@ -807,7 +807,7 @@ func (tr *TypeRepr) Set(repr string, fargs ...interface{}) { func (c *typeConv) Type(dtype dwarf.Type) *Type { if t, ok := c.m[dtype]; ok { if t.Go == nil { - fatal("type conversion loop at %s", dtype) + fatalf("type conversion loop at %s", dtype) } return t } @@ -830,11 +830,11 @@ func (c *typeConv) Type(dtype dwarf.Type) *Type { switch dt := dtype.(type) { default: - fatal("unexpected type: %s", dtype) + fatalf("unexpected type: %s", dtype) case *dwarf.AddrType: if t.Size != c.ptrSize { - fatal("unexpected: %d-byte address type - %s", t.Size, dtype) + fatalf("unexpected: %d-byte address type - %s", t.Size, dtype) } t.Go = c.uintptr t.Align = t.Size @@ -860,7 +860,7 @@ func (c *typeConv) Type(dtype dwarf.Type) *Type { case *dwarf.CharType: if t.Size != 1 { - fatal("unexpected: %d-byte char type - %s", t.Size, dtype) + fatalf("unexpected: %d-byte char type - %s", t.Size, dtype) } t.Go = c.int8 t.Align = 1 @@ -880,7 +880,7 @@ func (c *typeConv) Type(dtype dwarf.Type) *Type { } switch t.Size + int64(signed) { default: - fatal("unexpected: 
%d-byte enum type - %s", t.Size, dtype) + fatalf("unexpected: %d-byte enum type - %s", t.Size, dtype) case 1: t.Go = c.uint8 case 2: @@ -902,7 +902,7 @@ func (c *typeConv) Type(dtype dwarf.Type) *Type { case *dwarf.FloatType: switch t.Size { default: - fatal("unexpected: %d-byte float type - %s", t.Size, dtype) + fatalf("unexpected: %d-byte float type - %s", t.Size, dtype) case 4: t.Go = c.float32 case 8: @@ -915,7 +915,7 @@ func (c *typeConv) Type(dtype dwarf.Type) *Type { case *dwarf.ComplexType: switch t.Size { default: - fatal("unexpected: %d-byte complex type - %s", t.Size, dtype) + fatalf("unexpected: %d-byte complex type - %s", t.Size, dtype) case 8: t.Go = c.complex64 case 16: @@ -933,11 +933,11 @@ func (c *typeConv) Type(dtype dwarf.Type) *Type { case *dwarf.IntType: if dt.BitSize > 0 { - fatal("unexpected: %d-bit int type - %s", dt.BitSize, dtype) + fatalf("unexpected: %d-bit int type - %s", dt.BitSize, dtype) } switch t.Size { default: - fatal("unexpected: %d-byte int type - %s", t.Size, dtype) + fatalf("unexpected: %d-byte int type - %s", t.Size, dtype) case 1: t.Go = c.int8 case 2: @@ -1022,18 +1022,18 @@ func (c *typeConv) Type(dtype dwarf.Type) *Type { case *dwarf.UcharType: if t.Size != 1 { - fatal("unexpected: %d-byte uchar type - %s", t.Size, dtype) + fatalf("unexpected: %d-byte uchar type - %s", t.Size, dtype) } t.Go = c.uint8 t.Align = 1 case *dwarf.UintType: if dt.BitSize > 0 { - fatal("unexpected: %d-bit uint type - %s", dt.BitSize, dtype) + fatalf("unexpected: %d-bit uint type - %s", dt.BitSize, dtype) } switch t.Size { default: - fatal("unexpected: %d-byte uint type - %s", t.Size, dtype) + fatalf("unexpected: %d-byte uint type - %s", t.Size, dtype) case 1: t.Go = c.uint8 case 2: @@ -1067,7 +1067,7 @@ func (c *typeConv) Type(dtype dwarf.Type) *Type { } if t.C.Empty() { - fatal("internal error: did not create C name for %s", dtype) + fatalf("internal error: did not create C name for %s", dtype) } return t @@ -1156,7 +1156,7 @@ func (c *typeConv) Opaque(n int64) ast.Expr { func (c *typeConv) intExpr(n int64) ast.Expr { return &ast.BasicLit{ Kind: token.INT, - Value: []byte(strconv.Itoa64(n)), + Value: strconv.Itoa64(n), } } @@ -1229,7 +1229,7 @@ func (c *typeConv) Struct(dt *dwarf.StructType) (expr *ast.StructType, csyntax s off = dt.ByteSize } if off != dt.ByteSize { - fatal("struct size calculation error") + fatalf("struct size calculation error") } buf.WriteString("}") csyntax = buf.String() diff --git a/src/cmd/cgo/main.go b/src/cmd/cgo/main.go index 2dc662de5..00ffc4506 100644 --- a/src/cmd/cgo/main.go +++ b/src/cmd/cgo/main.go @@ -177,11 +177,11 @@ func main() { arch := os.Getenv("GOARCH") if arch == "" { - fatal("$GOARCH is not set") + fatalf("$GOARCH is not set") } ptrSize := ptrSizeMap[arch] if ptrSize == 0 { - fatal("unknown $GOARCH %q", arch) + fatalf("unknown $GOARCH %q", arch) } // Clear locale variables so gcc emits English errors [sic]. @@ -203,9 +203,9 @@ func main() { // Use the beginning of the md5 of the input to disambiguate. 
h := md5.New() for _, input := range goFiles { - f, err := os.Open(input, os.O_RDONLY, 0) + f, err := os.Open(input) if err != nil { - fatal("%s", err) + fatalf("%s", err) } io.Copy(h, f) f.Close() diff --git a/src/cmd/cgo/out.go b/src/cmd/cgo/out.go index 4a5fa6a73..abf8c8bc2 100644 --- a/src/cmd/cgo/out.go +++ b/src/cmd/cgo/out.go @@ -105,7 +105,7 @@ func dynimport(obj string) (syms, imports []string) { if f, err1 = elf.Open(obj); err1 != nil { if f, err2 = pe.Open(obj); err2 != nil { if f, err3 = macho.Open(obj); err3 != nil { - fatal("cannot parse %s as ELF (%v) or PE (%v) or Mach-O (%v)", obj, err1, err2, err3) + fatalf("cannot parse %s as ELF (%v) or PE (%v) or Mach-O (%v)", obj, err1, err2, err3) } isMacho = true } @@ -114,7 +114,7 @@ func dynimport(obj string) (syms, imports []string) { var err os.Error syms, err = f.ImportedSymbols() if err != nil { - fatal("cannot load dynamic symbols: %v", err) + fatalf("cannot load dynamic symbols: %v", err) } if isMacho { // remove leading _ that OS X insists on @@ -127,7 +127,7 @@ func dynimport(obj string) (syms, imports []string) { imports, err = f.ImportedLibraries() if err != nil { - fatal("cannot load dynamic imports: %v", err) + fatalf("cannot load dynamic imports: %v", err) } return diff --git a/src/cmd/cgo/util.go b/src/cmd/cgo/util.go index 59529a6d2..1ca24103e 100644 --- a/src/cmd/cgo/util.go +++ b/src/cmd/cgo/util.go @@ -18,23 +18,23 @@ import ( func run(stdin []byte, argv []string) (stdout, stderr []byte, ok bool) { cmd, err := exec.LookPath(argv[0]) if err != nil { - fatal("exec %s: %s", argv[0], err) + fatalf("exec %s: %s", argv[0], err) } r0, w0, err := os.Pipe() if err != nil { - fatal("%s", err) + fatalf("%s", err) } r1, w1, err := os.Pipe() if err != nil { - fatal("%s", err) + fatalf("%s", err) } r2, w2, err := os.Pipe() if err != nil { - fatal("%s", err) + fatalf("%s", err) } - p, err := os.StartProcess(cmd, argv, os.Environ(), "", []*os.File{r0, w1, w2}) + p, err := os.StartProcess(cmd, argv, &os.ProcAttr{Files: []*os.File{r0, w1, w2}}) if err != nil { - fatal("%s", err) + fatalf("%s", err) } defer p.Release() r0.Close() @@ -58,14 +58,14 @@ func run(stdin []byte, argv []string) (stdout, stderr []byte, ok bool) { w, err := p.Wait(0) if err != nil { - fatal("%s", err) + fatalf("%s", err) } ok = w.Exited() && w.ExitStatus() == 0 return } // Die with an error message. -func fatal(msg string, args ...interface{}) { +func fatalf(msg string, args ...interface{}) { fmt.Fprintf(os.Stderr, msg+"\n", args...) 
os.Exit(2) } @@ -95,9 +95,9 @@ func isName(s string) bool { } func creat(name string) *os.File { - f, err := os.Open(name, os.O_WRONLY|os.O_CREAT|os.O_TRUNC, 0666) + f, err := os.Create(name) if err != nil { - fatal("%s", err) + fatalf("%s", err) } return f } diff --git a/src/cmd/gc/align.c b/src/cmd/gc/align.c index 833eba19a..a01e2ea46 100644 --- a/src/cmd/gc/align.c +++ b/src/cmd/gc/align.c @@ -172,6 +172,9 @@ dowidth(Type *t) w = 8; checkwidth(t->type); break; + case TUNSAFEPTR: + w = widthptr; + break; case TINTER: // implemented as 2 pointers w = 2*widthptr; t->align = widthptr; @@ -400,6 +403,13 @@ typeinit(void) types[TPTR64] = typ(TPTR64); dowidth(types[TPTR64]); + + t = typ(TUNSAFEPTR); + types[TUNSAFEPTR] = t; + t->sym = pkglookup("Pointer", unsafepkg); + t->sym->def = typenod(t); + + dowidth(types[TUNSAFEPTR]); tptr = TPTR32; if(widthptr == 8) @@ -481,6 +491,7 @@ typeinit(void) okforeq[TPTR32] = 1; okforeq[TPTR64] = 1; + okforeq[TUNSAFEPTR] = 1; okforeq[TINTER] = 1; okforeq[TMAP] = 1; okforeq[TCHAN] = 1; @@ -570,6 +581,7 @@ typeinit(void) simtype[TMAP] = tptr; simtype[TCHAN] = tptr; simtype[TFUNC] = tptr; + simtype[TUNSAFEPTR] = tptr; /* pick up the backend typedefs */ for(i=0; typedefs[i].name; i++) { diff --git a/src/cmd/gc/builtin.c.boot b/src/cmd/gc/builtin.c.boot index 48f45293f..bdbca7f78 100644 --- a/src/cmd/gc/builtin.c.boot +++ b/src/cmd/gc/builtin.c.boot @@ -66,15 +66,17 @@ char *runtimeimport = "func \"\".mapiter2 (hiter *any) (key any, val any)\n" "func \"\".makechan (elem *uint8, hint int64) chan any\n" "func \"\".chanrecv1 (hchan <-chan any) any\n" - "func \"\".chanrecv3 (hchan <-chan any) (elem any, closed bool)\n" + "func \"\".chanrecv2 (hchan <-chan any) (elem any, received bool)\n" "func \"\".chansend1 (hchan chan<- any, elem any)\n" "func \"\".closechan (hchan any)\n" "func \"\".closedchan (hchan any) bool\n" "func \"\".selectnbsend (hchan chan<- any, elem any) bool\n" "func \"\".selectnbrecv (elem *any, hchan <-chan any) bool\n" + "func \"\".selectnbrecv2 (elem *any, received *bool, hchan <-chan any) bool\n" "func \"\".newselect (size int) *uint8\n" "func \"\".selectsend (sel *uint8, hchan chan<- any, elem any) bool\n" "func \"\".selectrecv (sel *uint8, hchan <-chan any, elem *any) bool\n" + "func \"\".selectrecv2 (sel *uint8, hchan <-chan any, elem *any, received *bool) bool\n" "func \"\".selectdefault (sel *uint8) bool\n" "func \"\".selectgo (sel *uint8)\n" "func \"\".block ()\n" @@ -96,7 +98,7 @@ char *runtimeimport = "$$\n"; char *unsafeimport = "package unsafe\n" - "type \"\".Pointer *any\n" + "type \"\".Pointer uintptr\n" "func \"\".Offsetof (? any) int\n" "func \"\".Sizeof (? any) int\n" "func \"\".Alignof (? any) int\n" diff --git a/src/cmd/gc/const.c b/src/cmd/gc/const.c index a54c40f6c..a36ec68c0 100644 --- a/src/cmd/gc/const.c +++ b/src/cmd/gc/const.c @@ -136,7 +136,6 @@ convlit1(Node **np, Type *t, int explicit) case CTNIL: switch(et) { default: - yyerror("cannot use nil as %T", t); n->type = T; goto bad; @@ -155,6 +154,7 @@ convlit1(Node **np, Type *t, int explicit) case TMAP: case TCHAN: case TFUNC: + case TUNSAFEPTR: break; } break; diff --git a/src/cmd/gc/dcl.c b/src/cmd/gc/dcl.c index a71272aa2..3089a23b0 100644 --- a/src/cmd/gc/dcl.c +++ b/src/cmd/gc/dcl.c @@ -22,7 +22,6 @@ dflag(void) /* * declaration stack & operations */ -static Sym* dclstack; static void dcopy(Sym *a, Sym *b) @@ -656,10 +655,19 @@ typedcl2(Type *pt, Type *t) { Node *n; + // override declaration in unsafe.go for Pointer. 
+ // there is no way in Go code to define unsafe.Pointer + // so we have to supply it. + if(incannedimport && + strcmp(importpkg->name, "unsafe") == 0 && + strcmp(pt->nod->sym->name, "Pointer") == 0) { + t = types[TUNSAFEPTR]; + } + if(pt->etype == TFORW) goto ok; if(!eqtype(pt->orig, t)) - yyerror("inconsistent definition for type %S during import\n\t%lT\n\t%lT", pt->sym, pt, t); + yyerror("inconsistent definition for type %S during import\n\t%lT\n\t%lT", pt->sym, pt->orig, t); return; ok: @@ -684,13 +692,13 @@ ok: * turn a parsed struct into a type */ static Type** -stotype(NodeList *l, int et, Type **t) +stotype(NodeList *l, int et, Type **t, int funarg) { Type *f, *t1, *t2, **t0; Strlit *note; int lno; NodeList *init; - Node *n; + Node *n, *left; char *what; t0 = t; @@ -707,15 +715,18 @@ stotype(NodeList *l, int et, Type **t) if(n->op != ODCLFIELD) fatal("stotype: oops %N\n", n); + left = n->left; + if(funarg && isblank(left)) + left = N; if(n->right != N) { - if(et == TINTER && n->left != N) { + if(et == TINTER && left != N) { // queue resolution of method type for later. // right now all we need is the name list. // avoids cycles for recursive interface types. n->type = typ(TINTERMETH); n->type->nname = n->right; n->right = N; - n->left->type = n->type; + left->type = n->type; queuemethod(n); } else { typecheck(&n->right, Etype); @@ -724,8 +735,8 @@ stotype(NodeList *l, int et, Type **t) *t0 = T; return t0; } - if(n->left != N) - n->left->type = n->type; + if(left != N) + left->type = n->type; n->right = N; if(n->embedded && n->type != T) { t1 = n->type; @@ -763,7 +774,7 @@ stotype(NodeList *l, int et, Type **t) break; } - if(et == TINTER && n->left == N) { + if(et == TINTER && left == N) { // embedded interface - inline the methods if(n->type->etype != TINTER) { if(n->type->etype == TFORW) @@ -796,8 +807,8 @@ stotype(NodeList *l, int et, Type **t) f->width = BADWIDTH; f->isddd = n->isddd; - if(n->left != N && n->left->op == ONAME) { - f->nname = n->left; + if(left != N && left->op == ONAME) { + f->nname = left; f->embedded = n->embedded; f->sym = f->nname->sym; if(importpkg && !exportname(f->sym->name)) @@ -839,7 +850,7 @@ dostruct(NodeList *l, int et) } t = typ(et); t->funarg = funarg; - stotype(l, et, &t->type); + stotype(l, et, &t->type, funarg); if(t->type == T && l != nil) { t->broke = 1; return t; @@ -933,8 +944,6 @@ checkarglist(NodeList *all, int input) t = n; n = N; } - if(isblank(n)) - n = N; if(n != N && n->sym == S) { t = n; n = N; @@ -1151,9 +1160,9 @@ addmethod(Sym *sf, Type *t, int local) } if(d == T) - stotype(list1(n), 0, &pa->method); + stotype(list1(n), 0, &pa->method, 0); else - stotype(list1(n), 0, &d->down); + stotype(list1(n), 0, &d->down, 0); return; } diff --git a/src/cmd/gc/export.c b/src/cmd/gc/export.c index 09b963f27..014f0c5f0 100644 --- a/src/cmd/gc/export.c +++ b/src/cmd/gc/export.c @@ -75,10 +75,15 @@ autoexport(Node *n, int ctxt) static void dumppkg(Pkg *p) { + char *suffix; + if(p == nil || p == localpkg || p->exported) return; p->exported = 1; - Bprint(bout, "\timport %s \"%Z\"\n", p->name, p->path); + suffix = ""; + if(!p->direct) + suffix = " // indirect"; + Bprint(bout, "\timport %s \"%Z\"%s\n", p->name, p->path, suffix); } static void @@ -265,7 +270,8 @@ void dumpexport(void) { NodeList *l; - int32 lno; + int32 i, lno; + Pkg *p; lno = lineno; @@ -277,6 +283,11 @@ dumpexport(void) Bprint(bout, " safe"); Bprint(bout, "\n"); + for(i=0; i<nelem(phash); i++) + for(p=phash[i]; p; p=p->link) + if(p->direct) + dumppkg(p); + for(l=exportlist; l; 
l=l->next) { lineno = l->n->lineno; dumpsym(l->n->sym); diff --git a/src/cmd/gc/gen.c b/src/cmd/gc/gen.c index 04af5a7bb..8ad6c437d 100644 --- a/src/cmd/gc/gen.c +++ b/src/cmd/gc/gen.c @@ -64,62 +64,83 @@ allocparams(void) lineno = lno; } +void +clearlabels(void) +{ + Label *l; + + for(l=labellist; l!=L; l=l->link) + l->sym->label = L; + + labellist = L; + lastlabel = L; +} + static void -newlab(int op, Sym *s, Node *stmt) +newlab(int op, Node *nlab, Node *stmt) { Label *lab; + Sym *s; + int32 lno; + + s = nlab->left->sym; + lno = nlab->left->lineno; lab = mal(sizeof(*lab)); - lab->link = labellist; - labellist = lab; + if(lastlabel == nil) + labellist = lab; + else + lastlabel->link = lab; + lastlabel = lab; + lab->lineno = lno; lab->sym = s; lab->op = op; lab->label = pc; lab->stmt = stmt; + if(op == OLABEL) { + if(s->label != L) { + lineno = lno; + yyerror("label %S already defined at %L", s, s->label->lineno); + } else + s->label = lab; + } } void checklabels(void) { - Label *l, *m; + Label *l; Sym *s; + int lno; -// // print the label list -// for(l=labellist; l!=L; l=l->link) { -// print("lab %O %S\n", l->op, l->sym); -// } - + lno = lineno; + + // resolve goto using syms for(l=labellist; l!=L; l=l->link) { - switch(l->op) { - case OLABEL: - // these are definitions - + switch(l->op) { + case OGOTO: s = l->sym; - for(m=labellist; m!=L; m=m->link) { - if(m->sym != s) - continue; - switch(m->op) { - case OLABEL: - // these are definitions - - // look for redefinitions - if(l != m) - yyerror("label %S redefined", s); - break; - case OGOTO: - // these are references - - // patch to definition - patch(m->label, l->label); - m->sym = S; // mark done - break; - } + if(s->label == L) { + lineno = l->lineno; + yyerror("label %S not defined", s); + break; } + s->label->used = 1; + patch(l->label, s->label->label); + break; } } - - // diagnostic for all undefined references - for(l=labellist; l!=L; l=l->link) - if(l->op == OGOTO && l->sym != S) - yyerror("label %S not defined", l->sym); + + // diagnose unused labels + for(l=labellist; l!=L; l=l->link) { + if(l->op == OLABEL && !l->used) { + lineno = l->lineno; + yyerror("label %S defined and not used", l->sym); + } + } + + lineno = lno; } /* @@ -171,7 +192,7 @@ gen(Node *n) // insert no-op so that // L:; for { } // does not treat L as a label for the loop. 
- if(labellist && labellist->label == p3) + if(lastlabel != L && lastlabel->label == p3) gused(N); break; @@ -180,26 +201,27 @@ gen(Node *n) break; case OLABEL: - newlab(OLABEL, n->left->sym, n->right); + newlab(OLABEL, n, n->right); break; case OGOTO: - newlab(OGOTO, n->left->sym, N); + newlab(OGOTO, n, N); gjmp(P); break; case OBREAK: if(n->left != N) { - for(lab=labellist; lab!=L; lab=lab->link) { - if(lab->sym == n->left->sym) { - if(lab->breakpc == P) - yyerror("invalid break label %S", n->left->sym); - gjmp(lab->breakpc); - goto donebreak; - } - } - if(lab == L) + lab = n->left->sym->label; + if(lab == L) { yyerror("break label not defined: %S", n->left->sym); + break; + } + lab->used = 1; + if(lab->breakpc == P) { + yyerror("invalid break label %S", n->left->sym); + break; + } + gjmp(lab->breakpc); break; } if(breakpc == P) { @@ -207,30 +229,28 @@ gen(Node *n) break; } gjmp(breakpc); - donebreak: break; case OCONTINUE: if(n->left != N) { - for(lab=labellist; lab!=L; lab=lab->link) { - if(lab->sym == n->left->sym) { - if(lab->continpc == P) - yyerror("invalid continue label %S", n->left->sym); - gjmp(lab->continpc); - goto donecont; - } - } - if(lab == L) + lab = n->left->sym->label; + if(lab == L) { yyerror("continue label not defined: %S", n->left->sym); + break; + } + lab->used = 1; + if(lab->continpc == P) { + yyerror("invalid continue label %S", n->left->sym); + break; + } + gjmp(lab->continpc); break; } - if(continpc == P) { yyerror("continue is not in a loop"); break; } gjmp(continpc); - donecont: break; case OFOR: @@ -241,10 +261,11 @@ gen(Node *n) continpc = pc; // define break and continue labels - if((lab = labellist) != L && lab->label == p3 && lab->op == OLABEL && lab->stmt == n) { + if((lab = lastlabel) != L && lab->label == p3 && lab->op == OLABEL && lab->stmt == n) { lab->breakpc = breakpc; lab->continpc = continpc; - } + } else + lab = L; gen(n->nincr); // contin: incr patch(p1, pc); // test: @@ -254,6 +275,10 @@ gen(Node *n) patch(breakpc, pc); // done: continpc = scontin; breakpc = sbreak; + if(lab) { + lab->breakpc = P; + lab->continpc = P; + } break; case OIF: @@ -274,13 +299,17 @@ gen(Node *n) breakpc = gjmp(P); // break: goto done // define break label - if((lab = labellist) != L && lab->label == p3 && lab->op == OLABEL && lab->stmt == n) + if((lab = lastlabel) != L && lab->label == p3 && lab->op == OLABEL && lab->stmt == n) lab->breakpc = breakpc; + else + lab = L; patch(p1, pc); // test: genlist(n->nbody); // switch(test) body patch(breakpc, pc); // done: breakpc = sbreak; + if(lab != L) + lab->breakpc = P; break; case OSELECT: @@ -289,13 +318,17 @@ gen(Node *n) breakpc = gjmp(P); // break: goto done // define break label - if((lab = labellist) != L && lab->label == p3 && lab->op == OLABEL && lab->stmt == n) + if((lab = lastlabel) != L && lab->label == p3 && lab->op == OLABEL && lab->stmt == n) lab->breakpc = breakpc; + else + lab = L; patch(p1, pc); // test: genlist(n->nbody); // select() body patch(breakpc, pc); // done: breakpc = sbreak; + if(lab != L) + lab->breakpc = P; break; case OASOP: diff --git a/src/cmd/gc/go.h b/src/cmd/gc/go.h index bf84c12a1..bb258a193 100644 --- a/src/cmd/gc/go.h +++ b/src/cmd/gc/go.h @@ -138,6 +138,7 @@ typedef struct Sym Sym; typedef struct Node Node; typedef struct NodeList NodeList; typedef struct Type Type; +typedef struct Label Label; struct Type { @@ -302,11 +303,14 @@ struct Sym Pkg* pkg; char* name; // variable name Node* def; // definition: ONAME OTYPE OPACK or OLITERAL + Label* label; // corresponding label 
(ephemeral) int32 block; // blocknumber to catch redeclaration int32 lastlineno; // last declaration for diagnostic }; #define S ((Sym*)0) +EXTERN Sym* dclstack; + struct Pkg { char* name; @@ -356,12 +360,11 @@ enum OARRAY, OARRAYBYTESTR, OARRAYRUNESTR, OSTRARRAYBYTE, OSTRARRAYRUNE, - OAS, OAS2, OAS2MAPW, OAS2FUNC, OAS2RECVCLOSED, OAS2MAPR, OAS2DOTTYPE, OASOP, + OAS, OAS2, OAS2MAPW, OAS2FUNC, OAS2RECV, OAS2MAPR, OAS2DOTTYPE, OASOP, OBAD, OCALL, OCALLFUNC, OCALLMETH, OCALLINTER, OCAP, OCLOSE, - OCLOSED, OCLOSURE, OCMPIFACE, OCMPSTR, OCOMPLIT, OMAPLIT, OSTRUCTLIT, OARRAYLIT, @@ -389,6 +392,7 @@ enum ORECV, ORUNESTR, OSELRECV, + OSELRECV2, OIOTA, OREAL, OIMAG, OCOMPLEX, @@ -441,27 +445,28 @@ enum TCOMPLEX64, // 12 TCOMPLEX128, - TFLOAT32, // 15 + TFLOAT32, // 14 TFLOAT64, - TBOOL, // 18 + TBOOL, // 16 - TPTR32, TPTR64, // 19 + TPTR32, TPTR64, // 17 - TFUNC, // 21 + TFUNC, // 19 TARRAY, T_old_DARRAY, - TSTRUCT, // 24 + TSTRUCT, // 22 TCHAN, TMAP, - TINTER, // 27 + TINTER, // 25 TFORW, TFIELD, TANY, TSTRING, + TUNSAFEPTR, // pseudo-types for literals - TIDEAL, // 32 + TIDEAL, // 31 TNIL, TBLANK, @@ -618,20 +623,22 @@ struct Magic typedef struct Prog Prog; -typedef struct Label Label; struct Label { uchar op; // OGOTO/OLABEL + uchar used; Sym* sym; Node* stmt; Prog* label; // pointer to code Prog* breakpc; // pointer to code Prog* continpc; // pointer to code Label* link; + int32 lineno; }; #define L ((Label*)0) EXTERN Label* labellist; +EXTERN Label* lastlabel; /* * note this is the runtime representation @@ -899,6 +906,7 @@ void allocparams(void); void cgen_as(Node *nl, Node *nr); void cgen_callmeth(Node *n, int proc); void checklabels(void); +void clearlabels(void); int dotoffset(Node *n, int *oary, Node **nn); void gen(Node *n); void genlist(NodeList *l); @@ -993,8 +1001,10 @@ int duint32(Sym *s, int off, uint32 v); int duint64(Sym *s, int off, uint64 v); int duint8(Sym *s, int off, uint8 v); int duintptr(Sym *s, int off, uint64 v); +int dsname(Sym *s, int off, char *dat, int ndat); void dumpobj(void); void ieeedtod(uint64 *ieee, double native); +Sym* stringsym(char*, int); /* * print.c @@ -1229,3 +1239,5 @@ void patch(Prog*, Prog*); void zfile(Biobuf *b, char *p, int n); void zhist(Biobuf *b, int line, vlong offset); void zname(Biobuf *b, Sym *s, int t); +void data(void); +void text(void); diff --git a/src/cmd/gc/go.y b/src/cmd/gc/go.y index 4b838a491..89899ae1e 100644 --- a/src/cmd/gc/go.y +++ b/src/cmd/gc/go.y @@ -461,23 +461,32 @@ case: } break; } -| LCASE expr '=' expr ':' +| LCASE expr_or_type_list '=' expr ':' { + Node *n; + // will be converted to OCASE // right will point to next case // done in casebody() poptodcl(); $$ = nod(OXCASE, N, N); - $$->list = list1(nod(OAS, $2, $4)); + if($2->next == nil) + n = nod(OAS, $2->n, $4); + else { + n = nod(OAS2, N, N); + n->list = $2; + n->rlist = list1($4); + } + $$->list = list1(n); } -| LCASE name LCOLAS expr ':' +| LCASE expr_or_type_list LCOLAS expr ':' { // will be converted to OCASE // right will point to next case // done in casebody() poptodcl(); $$ = nod(OXCASE, N, N); - $$->list = list1(colas(list1($2), list1($4))); + $$->list = list1(colas($2, list1($4))); } | LDEFAULT ':' { @@ -1230,9 +1239,10 @@ fnlitdcl: } fnliteral: - fnlitdcl '{' stmt_list '}' + fnlitdcl lbrace stmt_list '}' { $$ = closurebody($3); + fixlbrace($2); } diff --git a/src/cmd/gc/lex.c b/src/cmd/gc/lex.c index e79d3b0f8..bfd96274e 100644 --- a/src/cmd/gc/lex.c +++ b/src/cmd/gc/lex.c @@ -124,9 +124,6 @@ main(int argc, char *argv[]) runtimepkg = 
mkpkg(strlit("runtime")); runtimepkg->name = "runtime"; - stringpkg = mkpkg(strlit("string")); - stringpkg->name = "string"; - typepkg = mkpkg(strlit("type")); typepkg->name = "type"; @@ -1555,7 +1552,6 @@ static struct "append", LNAME, Txxx, OAPPEND, "cap", LNAME, Txxx, OCAP, "close", LNAME, Txxx, OCLOSE, - "closed", LNAME, Txxx, OCLOSED, "complex", LNAME, Txxx, OCOMPLEX, "copy", LNAME, Txxx, OCOPY, "imag", LNAME, Txxx, OIMAG, diff --git a/src/cmd/gc/obj.c b/src/cmd/gc/obj.c index fbabe0d43..9f4b7b318 100644 --- a/src/cmd/gc/obj.c +++ b/src/cmd/gc/obj.c @@ -235,3 +235,57 @@ duintptr(Sym *s, int off, uint64 v) { return duintxx(s, off, v, widthptr); } + +Sym* +stringsym(char *s, int len) +{ + static int gen; + Sym *sym; + int off, n, m; + struct { + Strlit lit; + char buf[110]; + } tmp; + Pkg *pkg; + + if(len > 100) { + // huge strings are made static to avoid long names + snprint(namebuf, sizeof(namebuf), ".gostring.%d", ++gen); + pkg = localpkg; + } else { + // small strings get named by their contents, + // so that multiple modules using the same string + // can share it. + tmp.lit.len = len; + memmove(tmp.lit.s, s, len); + tmp.lit.s[len] = '\0'; + snprint(namebuf, sizeof(namebuf), "\"%Z\"", &tmp); + pkg = gostringpkg; + } + sym = pkglookup(namebuf, pkg); + + // SymUniq flag indicates that data is generated already + if(sym->flags & SymUniq) + return sym; + sym->flags |= SymUniq; + + data(); + off = 0; + + // string header + off = dsymptr(sym, off, sym, widthptr+4); + off = duint32(sym, off, len); + + // string data + for(n=0; n<len; n+=m) { + m = 8; + if(m > len-n) + m = len-n; + off = dsname(sym, off, s+n, m); + } + off = duint8(sym, off, 0); // terminating NUL for runtime + ggloblsym(sym, off, 1); + text(); + + return sym; +} diff --git a/src/cmd/gc/print.c b/src/cmd/gc/print.c index 695a5a397..fee37f6d0 100644 --- a/src/cmd/gc/print.c +++ b/src/cmd/gc/print.c @@ -52,7 +52,6 @@ exprfmt(Fmt *f, Node *n, int prec) case OARRAYBYTESTR: case OCAP: case OCLOSE: - case OCLOSED: case OCOPY: case OLEN: case OMAKE: @@ -405,7 +404,6 @@ exprfmt(Fmt *f, Node *n, int prec) case OAPPEND: case OCAP: case OCLOSE: - case OCLOSED: case OLEN: case OCOPY: case OMAKE: diff --git a/src/cmd/gc/range.c b/src/cmd/gc/range.c index 4ee8f39a7..dfb2b8efd 100644 --- a/src/cmd/gc/range.c +++ b/src/cmd/gc/range.c @@ -203,8 +203,8 @@ walkrange(Node *n) hb = nod(OXXX, N, N); tempname(hb, types[TBOOL]); - n->ntest = nod(ONOT, hb, N); - a = nod(OAS2RECVCLOSED, N, N); + n->ntest = nod(ONE, hb, nodbool(0)); + a = nod(OAS2RECV, N, N); a->typecheck = 1; a->list = list(list1(hv1), hb); a->rlist = list1(nod(ORECV, ha, N)); diff --git a/src/cmd/gc/reflect.c b/src/cmd/gc/reflect.c index 8129bf1ce..b98e820c6 100644 --- a/src/cmd/gc/reflect.c +++ b/src/cmd/gc/reflect.c @@ -348,17 +348,19 @@ dimportpath(Pkg *p) * uncommonType * ../../pkg/runtime/type.go:/uncommonType */ -static Sym* -dextratype(Type *t) +static int +dextratype(Sym *sym, int off, Type *t, int ptroff) { int ot, n; - char *p; Sym *s; Sig *a, *m; m = methods(t); if(t->sym == nil && m == nil) - return nil; + return off; + + // fill in *extraType pointer in header + dsymptr(sym, ptroff, sym, off); n = 0; for(a=m; a; a=a->link) { @@ -366,9 +368,8 @@ dextratype(Type *t) n++; } - p = smprint("_.%#T", t); - s = pkglookup(p, typepkg); - ot = 0; + ot = off; + s = sym; if(t->sym) { ot = dgostringptr(s, ot, t->sym->name); if(t != types[t->etype]) @@ -402,9 +403,8 @@ dextratype(Type *t) else ot = duintptr(s, ot, 0); } - ggloblsym(s, ot, 0); - return s; + return ot; } enum { @@ 
-466,6 +466,7 @@ kinds[] = [TFUNC] = KindFunc, [TCOMPLEX64] = KindComplex64, [TCOMPLEX128] = KindComplex128, + [TUNSAFEPTR] = KindUnsafePointer, }; static char* @@ -488,6 +489,7 @@ structnames[] = [TFLOAT64] = "*runtime.FloatType", [TBOOL] = "*runtime.BoolType", [TSTRING] = "*runtime.StringType", + [TUNSAFEPTR] = "*runtime.UnsafePointerType", [TPTR32] = "*runtime.PtrType", [TPTR64] = "*runtime.PtrType", @@ -514,9 +516,6 @@ typestruct(Type *t) if(isslice(t)) name = "*runtime.SliceType"; - if(isptr[et] && t->type->etype == TANY) - name = "*runtime.UnsafePointerType"; - return pkglookup(name, typepkg); } @@ -553,6 +552,7 @@ haspointers(Type *t) case TSTRING: case TPTR32: case TPTR64: + case TUNSAFEPTR: case TINTER: case TCHAN: case TMAP: @@ -570,7 +570,6 @@ static int dcommontype(Sym *s, int ot, Type *t) { int i; - Sym *s1; Sym *sptr; char *p; @@ -582,8 +581,6 @@ dcommontype(Sym *s, int ot, Type *t) else sptr = weaktypesym(ptrto(t)); - s1 = dextratype(t); - // empty interface pointing at this type. // all the references that we emit are *interface{}; // they point here. @@ -612,8 +609,6 @@ dcommontype(Sym *s, int ot, Type *t) i = kinds[t->etype]; if(t->etype == TARRAY && t->bound < 0) i = KindSlice; - if(isptr[t->etype] && t->type->etype == TANY) - i = KindUnsafePointer; if(!haspointers(t)) i |= KindNoPointers; ot = duint8(s, ot, i); // kind @@ -622,11 +617,14 @@ dcommontype(Sym *s, int ot, Type *t) longsymnames = 0; ot = dgostringptr(s, ot, p); // string free(p); - if(s1) - ot = dsymptr(s, ot, s1, 0); // extraType - else - ot = duintptr(s, ot, 0); - ot = dsymptr(s, ot, sptr, 0); // ptr to type + + // skip pointer to extraType, + // which follows the rest of this type structure. + // caller will fill in if needed. + // otherwise linker will assume 0. + ot += widthptr; + + ot = dsymptr(s, ot, sptr, 0); // ptrto type return ot; } @@ -693,7 +691,7 @@ weaktypesym(Type *t) static Sym* dtypesym(Type *t) { - int ot, n, isddd, dupok; + int ot, xt, n, isddd, dupok; Sym *s, *s1, *s2; Sig *a, *m; Type *t1, *tbase; @@ -714,12 +712,8 @@ dtypesym(Type *t) tbase = t->type; dupok = tbase->sym == S; - if(compiling_runtime) { - if(tbase == types[tbase->etype]) // int, float, etc - goto ok; - if(tbase->etype == tptr && tbase->type->etype == TANY) // unsafe.Pointer - goto ok; - } + if(compiling_runtime && tbase == types[tbase->etype]) // int, float, etc + goto ok; // named types from other files are defined only by those files if(tbase->sym && !tbase->local) @@ -729,15 +723,18 @@ dtypesym(Type *t) ok: ot = 0; + xt = 0; switch(t->etype) { default: ot = dcommontype(s, ot, t); + xt = ot - 2*widthptr; break; case TARRAY: // ../../pkg/runtime/type.go:/ArrayType s1 = dtypesym(t->type); ot = dcommontype(s, ot, t); + xt = ot - 2*widthptr; ot = dsymptr(s, ot, s1, 0); if(t->bound < 0) ot = duintptr(s, ot, -1); @@ -749,6 +746,7 @@ ok: // ../../pkg/runtime/type.go:/ChanType s1 = dtypesym(t->type); ot = dcommontype(s, ot, t); + xt = ot - 2*widthptr; ot = dsymptr(s, ot, s1, 0); ot = duintptr(s, ot, t->chan); break; @@ -765,6 +763,7 @@ ok: dtypesym(t1->type); ot = dcommontype(s, ot, t); + xt = ot - 2*widthptr; ot = duint8(s, ot, isddd); // two slice headers: in and out. 
@@ -796,6 +795,7 @@ ok: // ../../pkg/runtime/type.go:/InterfaceType ot = dcommontype(s, ot, t); + xt = ot - 2*widthptr; ot = dsymptr(s, ot, s, ot+widthptr+2*4); ot = duint32(s, ot, n); ot = duint32(s, ot, n); @@ -812,6 +812,7 @@ ok: s1 = dtypesym(t->down); s2 = dtypesym(t->type); ot = dcommontype(s, ot, t); + xt = ot - 2*widthptr; ot = dsymptr(s, ot, s1, 0); ot = dsymptr(s, ot, s2, 0); break; @@ -826,6 +827,7 @@ ok: // ../../pkg/runtime/type.go:/PtrType s1 = dtypesym(t->type); ot = dcommontype(s, ot, t); + xt = ot - 2*widthptr; ot = dsymptr(s, ot, s1, 0); break; @@ -838,6 +840,7 @@ ok: n++; } ot = dcommontype(s, ot, t); + xt = ot - 2*widthptr; ot = dsymptr(s, ot, s, ot+widthptr+2*4); ot = duint32(s, ot, n); ot = duint32(s, ot, n); @@ -859,7 +862,7 @@ ok: } break; } - + ot = dextratype(s, ot, t, xt); ggloblsym(s, ot, dupok); return s; } @@ -908,7 +911,7 @@ dumptypestructs(void) for(i=1; i<=TBOOL; i++) dtypesym(ptrto(types[i])); dtypesym(ptrto(types[TSTRING])); - dtypesym(ptrto(pkglookup("Pointer", unsafepkg)->def->type)); + dtypesym(ptrto(types[TUNSAFEPTR])); // add paths for runtime and main, which 6l imports implicitly. dimportpath(runtimepkg); diff --git a/src/cmd/gc/runtime.go b/src/cmd/gc/runtime.go index bf7d045c0..35d11eca9 100644 --- a/src/cmd/gc/runtime.go +++ b/src/cmd/gc/runtime.go @@ -92,17 +92,19 @@ func mapiter2(hiter *any) (key any, val any) // *byte is really *runtime.Type func makechan(elem *byte, hint int64) (hchan chan any) func chanrecv1(hchan <-chan any) (elem any) -func chanrecv3(hchan <-chan any) (elem any, closed bool) +func chanrecv2(hchan <-chan any) (elem any, received bool) func chansend1(hchan chan<- any, elem any) func closechan(hchan any) func closedchan(hchan any) bool func selectnbsend(hchan chan<- any, elem any) bool func selectnbrecv(elem *any, hchan <-chan any) bool +func selectnbrecv2(elem *any, received *bool, hchan <-chan any) bool func newselect(size int) (sel *byte) func selectsend(sel *byte, hchan chan<- any, elem any) (selected bool) func selectrecv(sel *byte, hchan <-chan any, elem *any) (selected bool) +func selectrecv2(sel *byte, hchan <-chan any, elem *any, received *bool) (selected bool) func selectdefault(sel *byte) (selected bool) func selectgo(sel *byte) func block() diff --git a/src/cmd/gc/select.c b/src/cmd/gc/select.c index 58a147745..91d4ebfd5 100644 --- a/src/cmd/gc/select.c +++ b/src/cmd/gc/select.c @@ -58,6 +58,18 @@ typecheckselect(Node *sel) n->op = OSELRECV; break; + case OAS2RECV: + // convert x, ok = <-c into OSELRECV(x, <-c) with ntest=ok + if(n->right->op != ORECV) { + yyerror("select assignment must have receive on right hand side"); + break; + } + n->op = OSELRECV2; + n->left = n->list->n; + n->ntest = n->list->next->n; + n->right = n->rlist->n; + break; + case ORECV: // convert <-c into OSELRECV(N, <-c) n = nod(OSELRECV, N, n); @@ -122,6 +134,18 @@ walkselect(Node *sel) typecheck(&n, Etop); } break; + + case OSELRECV2: + r = n->right; + ch = cheapexpr(r->left, &l); + r->left = ch; + + a = nod(OAS2, N, N); + a->list = n->list; + a->rlist = n->rlist; + n = a; + typecheck(&n, Etop); + break; } // if ch == nil { block() }; n; @@ -146,6 +170,7 @@ walkselect(Node *sel) continue; switch(n->op) { case OSELRECV: + case OSELRECV2: ch = n->right->left; // If we can use the address of the target without @@ -154,6 +179,28 @@ walkselect(Node *sel) // Also introduce a temporary for := variables that escape, // so that we can delay the heap allocation until the case // is selected. 
+ if(n->op == OSELRECV2) { + if(n->ntest == N || isblank(n->ntest)) + n->ntest = nodnil(); + else if(n->ntest->op == ONAME && + (!n->colas || (n->ntest->class&PHEAP) == 0) && + convertop(types[TBOOL], n->ntest->type, nil) == OCONVNOP) { + n->ntest = nod(OADDR, n->ntest, N); + n->ntest->etype = 1; // pointer does not escape + typecheck(&n->ntest, Erv); + } else { + tmp = nod(OXXX, N, N); + tempname(tmp, types[TBOOL]); + a = nod(OADDR, tmp, N); + a->etype = 1; // pointer does not escape + typecheck(&a, Erv); + r = nod(OAS, n->ntest, tmp); + typecheck(&r, Etop); + cas->nbody = concat(list1(r), cas->nbody); + n->ntest = a; + } + } + if(n->left == N || isblank(n->left)) n->left = nodnil(); else if(n->left->op == ONAME && @@ -171,10 +218,12 @@ walkselect(Node *sel) r = nod(OAS, n->left, tmp); typecheck(&r, Etop); cas->nbody = concat(list1(r), cas->nbody); - cas->nbody = concat(n->ninit, cas->nbody); - n->ninit = nil; n->left = a; } + + cas->nbody = concat(n->ninit, cas->nbody); + n->ninit = nil; + break; } } @@ -212,6 +261,16 @@ walkselect(Node *sel) mkcall1(chanfn("selectnbrecv", 2, ch->type), types[TBOOL], &r->ninit, n->left, ch)); break; + + case OSELRECV2: + // if c != nil && selectnbrecv2(&v, c) { body } else { default body } + r = nod(OIF, N, N); + r->ninit = cas->ninit; + ch = cheapexpr(n->right->left, &r->ninit); + r->ntest = nod(OANDAND, nod(ONE, ch, nodnil()), + mkcall1(chanfn("selectnbrecv2", 2, ch->type), + types[TBOOL], &r->ninit, n->left, n->ntest, ch)); + break; } typecheck(&r->ntest, Erv); r->nbody = cas->nbody; @@ -254,11 +313,18 @@ walkselect(Node *sel) r->ntest = mkcall1(chanfn("selectsend", 2, n->left->type), types[TBOOL], &init, var, n->left, n->right); break; + case OSELRECV: // selectrecv(sel *byte, hchan *chan any, elem *any) (selected bool); r->ntest = mkcall1(chanfn("selectrecv", 2, n->right->left->type), types[TBOOL], &init, var, n->right->left, n->left); break; + + case OSELRECV2: + // selectrecv2(sel *byte, hchan *chan any, elem *any, received *bool) (selected bool); + r->ntest = mkcall1(chanfn("selectrecv2", 2, n->right->left->type), types[TBOOL], + &init, var, n->right->left, n->left, n->ntest); + break; } } r->nbody = concat(r->nbody, cas->nbody); diff --git a/src/cmd/gc/sinit.c b/src/cmd/gc/sinit.c index 31781646d..be96a1477 100644 --- a/src/cmd/gc/sinit.c +++ b/src/cmd/gc/sinit.c @@ -94,7 +94,7 @@ init1(Node *n, NodeList **out) case OAS2FUNC: case OAS2MAPR: case OAS2DOTTYPE: - case OAS2RECVCLOSED: + case OAS2RECV: if(n->defn->initorder) break; n->defn->initorder = 1; diff --git a/src/cmd/gc/subr.c b/src/cmd/gc/subr.c index 142e5ba41..2098794a7 100644 --- a/src/cmd/gc/subr.c +++ b/src/cmd/gc/subr.c @@ -135,6 +135,7 @@ yyerror(char *fmt, ...) int i; static int lastsyntax; va_list arg; + char buf[512], *p; if(strncmp(fmt, "syntax error", 12) == 0) { nsyntaxerrors++; @@ -147,6 +148,16 @@ yyerror(char *fmt, ...) return; lastsyntax = lexlineno; + if(strstr(fmt, "{ or {")) { + // The grammar has { and LBRACE but both show up as {. + // Rewrite syntax error referring to "{ or {" to say just "{". + strecpy(buf, buf+sizeof buf, fmt); + p = strstr(buf, "{ or {"); + if(p) + memmove(p+1, p+6, strlen(p+6)+1); + fmt = buf; + } + // look for parse state-specific errors in list (see go.errors). 
for(i=0; i<nelem(yymsg); i++) { if(yymsg[i].yystate == yystate && yymsg[i].yychar == yychar) { @@ -834,7 +845,6 @@ goopnames[] = [OCALL] = "function call", [OCAP] = "cap", [OCASE] = "case", - [OCLOSED] = "closed", [OCLOSE] = "close", [OCOMPLEX] = "complex", [OCOM] = "^", @@ -1144,7 +1154,7 @@ Tpretty(Fmt *fp, Type *t) && t->sym != S && !(fp->flags&FmtLong)) { s = t->sym; - if(t == types[t->etype]) + if(t == types[t->etype] && t->etype != TUNSAFEPTR) return fmtprint(fp, "%s", s->name); if(exporting) { if(fp->flags & FmtShort) @@ -1304,6 +1314,11 @@ Tpretty(Fmt *fp, Type *t) if(t->sym) return fmtprint(fp, "undefined %S", t->sym); return fmtprint(fp, "undefined"); + + case TUNSAFEPTR: + if(exporting) + return fmtprint(fp, "\"unsafe\".Pointer"); + return fmtprint(fp, "unsafe.Pointer"); } // Don't know how to handle - fall back to detailed prints. @@ -1346,6 +1361,9 @@ Tconv(Fmt *fp) } } + if(sharp || exporting) + fatal("missing %E case during export", t->etype); + et = t->etype; fmtprint(fp, "%E ", et); if(t->sym != S) @@ -1663,6 +1681,9 @@ isselect(Node *n) s = pkglookup("selectrecv", runtimepkg); if(s == n->sym) return 1; + s = pkglookup("selectrecv2", runtimepkg); + if(s == n->sym) + return 1; s = pkglookup("selectdefault", runtimepkg); if(s == n->sym) return 1; @@ -1864,7 +1885,7 @@ assignop(Type *src, Type *dst, char **why) if(why != nil) *why = ""; - if(safemode && (isptrto(src, TANY) || isptrto(dst, TANY))) { + if(safemode && src != T && src->etype == TUNSAFEPTR) { yyerror("cannot use unsafe.Pointer"); errorexit(); } @@ -1879,8 +1900,9 @@ assignop(Type *src, Type *dst, char **why) return OCONVNOP; // 2. src and dst have identical underlying types - // and either src or dst is not a named type. - if(eqtype(src->orig, dst->orig) && (src->sym == S || dst->sym == S)) + // and either src or dst is not a named type or + // both are interface types. + if(eqtype(src->orig, dst->orig) && (src->sym == S || dst->sym == S || src->etype == TINTER)) return OCONVNOP; // 3. dst is an interface type and src implements dst. @@ -2028,11 +2050,11 @@ convertop(Type *src, Type *dst, char **why) } // 8. src is a pointer or uintptr and dst is unsafe.Pointer. - if((isptr[src->etype] || src->etype == TUINTPTR) && isptrto(dst, TANY)) + if((isptr[src->etype] || src->etype == TUINTPTR) && dst->etype == TUNSAFEPTR) return OCONVNOP; // 9. src is unsafe.Pointer and dst is a pointer or uintptr. 
- if(isptrto(src, TANY) && (isptr[dst->etype] || dst->etype == TUINTPTR)) + if(src->etype == TUNSAFEPTR && (isptr[dst->etype] || dst->etype == TUINTPTR)) return OCONVNOP; return 0; @@ -2043,13 +2065,16 @@ Node* assignconv(Node *n, Type *t, char *context) { int op; - Node *r; + Node *r, *old; char *why; if(n == N || n->type == T) return n; + old = n; + old->diag++; // silence errors about n; we'll issue one below defaultlit(&n, t); + old->diag--; if(t->etype == TBLANK) return n; diff --git a/src/cmd/gc/typecheck.c b/src/cmd/gc/typecheck.c index 3e8f35877..1cc5abd5c 100644 --- a/src/cmd/gc/typecheck.c +++ b/src/cmd/gc/typecheck.c @@ -318,7 +318,7 @@ reswitch: n->left = N; goto ret; } - if(!isptr[t->etype] || (t->type != T && t->type->etype == TANY) /* unsafe.Pointer */) { + if(!isptr[t->etype]) { yyerror("invalid indirect of %+N", n->left); goto error; } @@ -921,7 +921,6 @@ reswitch: n->type = t; goto ret; - case OCLOSED: case OCLOSE: if(onearg(n, "%#O", n->op) < 0) goto error; @@ -934,11 +933,7 @@ reswitch: yyerror("invalid operation: %#N (non-chan type %T)", n, t); goto error; } - if(n->op == OCLOSED) { - n->type = types[TBOOL]; - ok |= Erv; - } else - ok |= Etop; + ok |= Etop; goto ret; case OAPPEND: @@ -1316,7 +1311,7 @@ ret: // TODO(rsc): should not need to check importpkg, // but reflect mentions unsafe.Pointer. - if(safemode && !incannedimport && !importpkg && isptrto(t, TANY)) + if(safemode && !incannedimport && !importpkg && t && t->etype == TUNSAFEPTR) yyerror("cannot use unsafe.Pointer"); evconst(n); @@ -1639,11 +1634,6 @@ typecheckaste(int op, Node *call, int isddd, Type *tstruct, NodeList *nl, char * for(tl=tstruct->type; tl; tl=tl->down) { t = tl->type; if(tl->isddd) { - if(nl != nil && nl->n->op == ONAME && nl->n->isddd && !isddd) { - // TODO(rsc): This is not actually illegal, but it will help catch bugs. 
- yyerror("to pass '%#N' as ...%T, use '%#N...'", nl->n, t->type, nl->n); - isddd = 1; - } if(isddd) { if(nl == nil) goto notenough; @@ -2377,8 +2367,9 @@ typecheckas2(Node *n) n->op = OAS2MAPR; goto common; case ORECV: - yyerror("cannot use multiple-value assignment for non-blocking receive; use select"); - goto out; + n->op = OAS2RECV; + n->right = n->rlist->n; + goto common; case ODOTTYPE: n->op = OAS2DOTTYPE; r->op = ODOTTYPE2; diff --git a/src/cmd/gc/unsafe.go b/src/cmd/gc/unsafe.go index bd7b7771a..b2a341d39 100644 --- a/src/cmd/gc/unsafe.go +++ b/src/cmd/gc/unsafe.go @@ -8,7 +8,7 @@ package PACKAGE -type Pointer *any +type Pointer uintptr // not really; filled in by compiler func Offsetof(any) int func Sizeof(any) int diff --git a/src/cmd/gc/walk.c b/src/cmd/gc/walk.c index b32b6fff5..b8c6842e0 100644 --- a/src/cmd/gc/walk.c +++ b/src/cmd/gc/walk.c @@ -403,12 +403,11 @@ walkstmt(Node **np) case OAS: case OAS2: case OAS2DOTTYPE: - case OAS2RECVCLOSED: + case OAS2RECV: case OAS2FUNC: case OAS2MAPW: case OAS2MAPR: case OCLOSE: - case OCLOSED: case OCOPY: case OCALLMETH: case OCALLINTER: @@ -822,14 +821,13 @@ walkexpr(Node **np, NodeList **init) n = liststmt(concat(concat(list1(r), ll), lpost)); goto ret; - case OAS2RECVCLOSED: - // a = <-c; b = closed(c) but atomic + case OAS2RECV: *init = concat(*init, n->ninit); n->ninit = nil; r = n->rlist->n; walkexprlistsafe(n->list, init); walkexpr(&r->left, init); - fn = chanfn("chanrecv3", 2, r->left->type); + fn = chanfn("chanrecv2", 2, r->left->type); r = mkcall1(fn, getoutargx(fn->type), init, r->left); n->rlist->n = r; n->op = OAS2FUNC; @@ -1309,13 +1307,6 @@ walkexpr(Node **np, NodeList **init) n = mkcall1(fn, T, init, n->left); goto ret; - case OCLOSED: - // cannot use chanfn - closechan takes any, not chan any - fn = syslook("closedchan", 1); - argtype(fn, n->left->type); - n = mkcall1(fn, n->type, init, n->left); - goto ret; - case OMAKECHAN: n = mkcall1(chanfn("makechan", 1, n->type), n->type, init, typename(n->type->type), diff --git a/src/cmd/godefs/stabs.c b/src/cmd/godefs/stabs.c index f2bb57eb6..30a05fc70 100644 --- a/src/cmd/godefs/stabs.c +++ b/src/cmd/godefs/stabs.c @@ -219,7 +219,7 @@ parsedef(char **pp, char *name) t = emalloc(sizeof *t); switch(*p) { default: - fprint(2, "unknown type char %c\n", *p); + fprint(2, "unknown type char %c in %s\n", *p, p); *pp = ""; return t; @@ -284,6 +284,7 @@ parsedef(char **pp, char *name) return nil; break; + case 'B': // volatile case 'k': // const ++*pp; return parsedef(pp, nil); diff --git a/src/cmd/godoc/codewalk.go b/src/cmd/godoc/codewalk.go index bda63a9f6..24087eb88 100644 --- a/src/cmd/godoc/codewalk.go +++ b/src/cmd/godoc/codewalk.go @@ -115,7 +115,7 @@ func (st *Codestep) String() string { // loadCodewalk reads a codewalk from the named XML file. func loadCodewalk(file string) (*Codewalk, os.Error) { - f, err := os.Open(file, os.O_RDONLY, 0) + f, err := os.Open(file) if err != nil { return nil, err } diff --git a/src/cmd/godoc/dirtrees.go b/src/cmd/godoc/dirtrees.go index 3ad7c8cfc..97737ca5a 100644 --- a/src/cmd/godoc/dirtrees.go +++ b/src/cmd/godoc/dirtrees.go @@ -266,8 +266,8 @@ func (dir *Directory) lookupLocal(name string) *Directory { // lookup looks for the *Directory for a given path, relative to dir. 
func (dir *Directory) lookup(path string) *Directory { - d := strings.Split(dir.Path, "/", -1) - p := strings.Split(path, "/", -1) + d := strings.Split(dir.Path, string(filepath.Separator), -1) + p := strings.Split(path, string(filepath.Separator), -1) i := 0 for i < len(d) { if i >= len(p) || d[i] != p[i] { @@ -342,8 +342,8 @@ func (root *Directory) listing(skipRoot bool) *DirList { if strings.HasPrefix(d.Path, root.Path) { path = d.Path[len(root.Path):] } - // remove trailing '/' if any - path must be relative - if len(path) > 0 && path[0] == '/' { + // remove trailing separator if any - path must be relative + if len(path) > 0 && path[0] == filepath.Separator { path = path[1:] } p.Path = path diff --git a/src/cmd/godoc/format.go b/src/cmd/godoc/format.go index da1466b21..7e6470846 100644 --- a/src/cmd/godoc/format.go +++ b/src/cmd/godoc/format.go @@ -292,7 +292,7 @@ func rangeSelection(str string) Selection { from, _ := strconv.Atoi(m[1]) to, _ := strconv.Atoi(m[2]) if from < to { - return makeSelection([][]int{[]int{from, to}}) + return makeSelection([][]int{{from, to}}) } } return nil @@ -309,7 +309,7 @@ func rangeSelection(str string) Selection { // var startTags = [][]byte{ /* 000 */ []byte(``), - /* 001 */ []byte(`<span class ="comment">`), + /* 001 */ []byte(`<span class="comment">`), /* 010 */ []byte(`<span class="highlight">`), /* 011 */ []byte(`<span class="highlight-comment">`), /* 100 */ []byte(`<span class="selection">`), diff --git a/src/cmd/godoc/godoc.go b/src/cmd/godoc/godoc.go index 9dce5edf9..b8e9dbc92 100644 --- a/src/cmd/godoc/godoc.go +++ b/src/cmd/godoc/godoc.go @@ -65,6 +65,7 @@ var ( tabwidth = flag.Int("tabwidth", 4, "tab width") showTimestamps = flag.Bool("timestamps", true, "show timestamps with directory listings") maxResults = flag.Int("maxresults", 10000, "maximum number of full text search results shown") + templateDir = flag.String("templates", "", "directory containing alternate template files") // file system mapping fsMap Mapping // user-defined mapping @@ -635,6 +636,14 @@ var fmap = template.FormatterMap{ func readTemplate(name string) *template.Template { path := filepath.Join(*goroot, "lib", "godoc", name) + if *templateDir != "" { + defaultpath := path + path = filepath.Join(*templateDir, name) + if _, err := os.Stat(path); err != nil { + log.Print("readTemplate:", err) + path = defaultpath + } + } data, err := ioutil.ReadFile(path) if err != nil { log.Fatalf("ReadFile %s: %v", path, err) @@ -702,7 +711,7 @@ func servePage(w http.ResponseWriter, title, subtitle, query string, content []b func serveText(w http.ResponseWriter, text []byte) { - w.SetHeader("Content-Type", "text/plain; charset=utf-8") + w.Header().Set("Content-Type", "text/plain; charset=utf-8") w.Write(text) } diff --git a/src/cmd/godoc/index.go b/src/cmd/godoc/index.go index 5af4d15cb..5938d0b74 100644 --- a/src/cmd/godoc/index.go +++ b/src/cmd/godoc/index.go @@ -624,7 +624,7 @@ func pkgName(filename string) string { // failed (that is, if the file was not added), it returns file == nil. 
func (x *Indexer) addFile(filename string, goFile bool) (file *token.File, ast *ast.File) { // open file - f, err := os.Open(filename, os.O_RDONLY, 0) + f, err := os.Open(filename) if err != nil { return } diff --git a/src/cmd/godoc/main.go b/src/cmd/godoc/main.go index 1ebb80279..e426626b3 100644 --- a/src/cmd/godoc/main.go +++ b/src/cmd/godoc/main.go @@ -83,7 +83,7 @@ func exec(rw http.ResponseWriter, args []string) (status int) { if *verbose { log.Printf("executing %v", args) } - p, err := os.StartProcess(bin, args, os.Environ(), *goroot, fds) + p, err := os.StartProcess(bin, args, &os.ProcAttr{Files: fds, Dir: *goroot}) defer r.Close() w.Close() if err != nil { @@ -111,7 +111,7 @@ func exec(rw http.ResponseWriter, args []string) (status int) { os.Stderr.Write(buf.Bytes()) } if rw != nil { - rw.SetHeader("content-type", "text/plain; charset=utf-8") + rw.Header().Set("Content-Type", "text/plain; charset=utf-8") rw.Write(buf.Bytes()) } @@ -152,7 +152,7 @@ func usage() { func loggingHandler(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - log.Printf("%s\t%s", w.RemoteAddr(), req.URL) + log.Printf("%s\t%s", req.RemoteAddr, req.URL) h.ServeHTTP(w, req) }) } @@ -222,6 +222,9 @@ func main() { flag.Usage = usage flag.Parse() + // Clean goroot: normalize path separator. + *goroot = filepath.Clean(*goroot) + // Check usage: either server and no args, or command line and args if (*httpAddr != "") != (flag.NArg() == 0) { usage() diff --git a/src/cmd/godoc/spec.go b/src/cmd/godoc/spec.go index a533c1e0a..f8b95e387 100644 --- a/src/cmd/godoc/spec.go +++ b/src/cmd/godoc/spec.go @@ -27,7 +27,7 @@ type ebnfParser struct { prev int // offset of previous token pos token.Pos // token position tok token.Token // one token look-ahead - lit []byte // token literal + lit string // token literal } @@ -63,7 +63,7 @@ func (p *ebnfParser) errorExpected(pos token.Pos, msg string) { // make the error message more specific msg += ", found '" + p.tok.String() + "'" if p.tok.IsLiteral() { - msg += " " + string(p.lit) + msg += " " + p.lit } } p.Error(p.file.Position(pos), msg) @@ -81,7 +81,7 @@ func (p *ebnfParser) expect(tok token.Token) token.Pos { func (p *ebnfParser) parseIdentifier(def bool) { - name := string(p.lit) + name := p.lit p.expect(token.IDENT) if def { fmt.Fprintf(p.out, `<a id="%s">%s</a>`, name, name) diff --git a/src/cmd/godoc/utils.go b/src/cmd/godoc/utils.go index 9517aee7a..593b51ce0 100644 --- a/src/cmd/godoc/utils.go +++ b/src/cmd/godoc/utils.go @@ -155,7 +155,7 @@ func isTextFile(filename string) bool { // the extension is not known; read an initial chunk // of the file and check if it looks like text - f, err := os.Open(filename, os.O_RDONLY, 0) + f, err := os.Open(filename) if err != nil { return false } diff --git a/src/cmd/gofix/Makefile b/src/cmd/gofix/Makefile new file mode 100644 index 000000000..12f09b4e4 --- /dev/null +++ b/src/cmd/gofix/Makefile @@ -0,0 +1,24 @@ +# Copyright 2011 The Go Authors. All rights reserved. +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. 
+ +include ../../Make.inc + +TARG=gofix +GOFILES=\ + fix.go\ + netdial.go\ + main.go\ + osopen.go\ + httpserver.go\ + procattr.go\ + reflect.go\ + typecheck.go\ + +include ../../Make.cmd + +test: + gotest + +testshort: + gotest -test.short diff --git a/src/cmd/gofix/doc.go b/src/cmd/gofix/doc.go new file mode 100644 index 000000000..902fe76f2 --- /dev/null +++ b/src/cmd/gofix/doc.go @@ -0,0 +1,34 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +/* +Gofix finds Go programs that use old APIs and rewrites them to use +newer ones. After you update to a new Go release, gofix helps make +the necessary changes to your programs. + +Usage: + gofix [-r name,...] [path ...] + +Without an explicit path, gofix reads standard input and writes the +result to standard output. + +If the named path is a file, gofix rewrites the named files in place. +If the named path is a directory, gofix rewrites all .go files in that +directory tree. When gofix rewrites a file, it prints a line to standard +error giving the name of the file and the rewrite applied. + +The -r flag restricts the set of rewrites considered to those in the +named list. By default gofix considers all known rewrites. Gofix's +rewrites are idempotent, so that it is safe to apply gofix to updated +or partially updated code even without using the -r flag. + +Gofix prints the full list of fixes it can apply in its help output; +to see them, run gofix -?. + +Gofix does not make backup copies of the files that it edits. +Instead, use a version control system's ``diff'' functionality to inspect +the changes that gofix makes before committing them. + +*/ +package documentation diff --git a/src/cmd/gofix/fix.go b/src/cmd/gofix/fix.go new file mode 100644 index 000000000..0852ce21e --- /dev/null +++ b/src/cmd/gofix/fix.go @@ -0,0 +1,422 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "fmt" + "go/ast" + "go/token" + "os" + "strconv" +) + +type fix struct { + name string + f func(*ast.File) bool + desc string +} + +// main runs sort.Sort(fixes) after init process is done. +type fixlist []fix + +func (f fixlist) Len() int { return len(f) } +func (f fixlist) Swap(i, j int) { f[i], f[j] = f[j], f[i] } +func (f fixlist) Less(i, j int) bool { return f[i].name < f[j].name } + +var fixes fixlist + +func register(f fix) { + fixes = append(fixes, f) +} + +// walk traverses the AST x, calling visit(y) for each node y in the tree but +// also with a pointer to each ast.Expr, ast.Stmt, and *ast.BlockStmt, +// in a bottom-up traversal. +func walk(x interface{}, visit func(interface{})) { + walkBeforeAfter(x, nop, visit) +} + +func nop(interface{}) {} + +// walkBeforeAfter is like walk but calls before(x) before traversing +// x's children and after(x) afterward. 
+func walkBeforeAfter(x interface{}, before, after func(interface{})) { + before(x) + + switch n := x.(type) { + default: + panic(fmt.Errorf("unexpected type %T in walkBeforeAfter", x)) + + case nil: + + // pointers to interfaces + case *ast.Decl: + walkBeforeAfter(*n, before, after) + case *ast.Expr: + walkBeforeAfter(*n, before, after) + case *ast.Spec: + walkBeforeAfter(*n, before, after) + case *ast.Stmt: + walkBeforeAfter(*n, before, after) + + // pointers to struct pointers + case **ast.BlockStmt: + walkBeforeAfter(*n, before, after) + case **ast.CallExpr: + walkBeforeAfter(*n, before, after) + case **ast.FieldList: + walkBeforeAfter(*n, before, after) + case **ast.FuncType: + walkBeforeAfter(*n, before, after) + + // pointers to slices + case *[]ast.Stmt: + walkBeforeAfter(*n, before, after) + case *[]ast.Expr: + walkBeforeAfter(*n, before, after) + case *[]ast.Decl: + walkBeforeAfter(*n, before, after) + case *[]ast.Spec: + walkBeforeAfter(*n, before, after) + case *[]*ast.File: + walkBeforeAfter(*n, before, after) + + // These are ordered and grouped to match ../../pkg/go/ast/ast.go + case *ast.Field: + walkBeforeAfter(&n.Type, before, after) + case *ast.FieldList: + for _, field := range n.List { + walkBeforeAfter(field, before, after) + } + case *ast.BadExpr: + case *ast.Ident: + case *ast.Ellipsis: + case *ast.BasicLit: + case *ast.FuncLit: + walkBeforeAfter(&n.Type, before, after) + walkBeforeAfter(&n.Body, before, after) + case *ast.CompositeLit: + walkBeforeAfter(&n.Type, before, after) + walkBeforeAfter(&n.Elts, before, after) + case *ast.ParenExpr: + walkBeforeAfter(&n.X, before, after) + case *ast.SelectorExpr: + walkBeforeAfter(&n.X, before, after) + case *ast.IndexExpr: + walkBeforeAfter(&n.X, before, after) + walkBeforeAfter(&n.Index, before, after) + case *ast.SliceExpr: + walkBeforeAfter(&n.X, before, after) + if n.Low != nil { + walkBeforeAfter(&n.Low, before, after) + } + if n.High != nil { + walkBeforeAfter(&n.High, before, after) + } + case *ast.TypeAssertExpr: + walkBeforeAfter(&n.X, before, after) + walkBeforeAfter(&n.Type, before, after) + case *ast.CallExpr: + walkBeforeAfter(&n.Fun, before, after) + walkBeforeAfter(&n.Args, before, after) + case *ast.StarExpr: + walkBeforeAfter(&n.X, before, after) + case *ast.UnaryExpr: + walkBeforeAfter(&n.X, before, after) + case *ast.BinaryExpr: + walkBeforeAfter(&n.X, before, after) + walkBeforeAfter(&n.Y, before, after) + case *ast.KeyValueExpr: + walkBeforeAfter(&n.Key, before, after) + walkBeforeAfter(&n.Value, before, after) + + case *ast.ArrayType: + walkBeforeAfter(&n.Len, before, after) + walkBeforeAfter(&n.Elt, before, after) + case *ast.StructType: + walkBeforeAfter(&n.Fields, before, after) + case *ast.FuncType: + walkBeforeAfter(&n.Params, before, after) + if n.Results != nil { + walkBeforeAfter(&n.Results, before, after) + } + case *ast.InterfaceType: + walkBeforeAfter(&n.Methods, before, after) + case *ast.MapType: + walkBeforeAfter(&n.Key, before, after) + walkBeforeAfter(&n.Value, before, after) + case *ast.ChanType: + walkBeforeAfter(&n.Value, before, after) + + case *ast.BadStmt: + case *ast.DeclStmt: + walkBeforeAfter(&n.Decl, before, after) + case *ast.EmptyStmt: + case *ast.LabeledStmt: + walkBeforeAfter(&n.Stmt, before, after) + case *ast.ExprStmt: + walkBeforeAfter(&n.X, before, after) + case *ast.SendStmt: + walkBeforeAfter(&n.Chan, before, after) + walkBeforeAfter(&n.Value, before, after) + case *ast.IncDecStmt: + walkBeforeAfter(&n.X, before, after) + case *ast.AssignStmt: + 
walkBeforeAfter(&n.Lhs, before, after) + walkBeforeAfter(&n.Rhs, before, after) + case *ast.GoStmt: + walkBeforeAfter(&n.Call, before, after) + case *ast.DeferStmt: + walkBeforeAfter(&n.Call, before, after) + case *ast.ReturnStmt: + walkBeforeAfter(&n.Results, before, after) + case *ast.BranchStmt: + case *ast.BlockStmt: + walkBeforeAfter(&n.List, before, after) + case *ast.IfStmt: + walkBeforeAfter(&n.Init, before, after) + walkBeforeAfter(&n.Cond, before, after) + walkBeforeAfter(&n.Body, before, after) + walkBeforeAfter(&n.Else, before, after) + case *ast.CaseClause: + walkBeforeAfter(&n.List, before, after) + walkBeforeAfter(&n.Body, before, after) + case *ast.SwitchStmt: + walkBeforeAfter(&n.Init, before, after) + walkBeforeAfter(&n.Tag, before, after) + walkBeforeAfter(&n.Body, before, after) + case *ast.TypeSwitchStmt: + walkBeforeAfter(&n.Init, before, after) + walkBeforeAfter(&n.Assign, before, after) + walkBeforeAfter(&n.Body, before, after) + case *ast.CommClause: + walkBeforeAfter(&n.Comm, before, after) + walkBeforeAfter(&n.Body, before, after) + case *ast.SelectStmt: + walkBeforeAfter(&n.Body, before, after) + case *ast.ForStmt: + walkBeforeAfter(&n.Init, before, after) + walkBeforeAfter(&n.Cond, before, after) + walkBeforeAfter(&n.Post, before, after) + walkBeforeAfter(&n.Body, before, after) + case *ast.RangeStmt: + walkBeforeAfter(&n.Key, before, after) + walkBeforeAfter(&n.Value, before, after) + walkBeforeAfter(&n.X, before, after) + walkBeforeAfter(&n.Body, before, after) + + case *ast.ImportSpec: + case *ast.ValueSpec: + walkBeforeAfter(&n.Type, before, after) + walkBeforeAfter(&n.Values, before, after) + case *ast.TypeSpec: + walkBeforeAfter(&n.Type, before, after) + + case *ast.BadDecl: + case *ast.GenDecl: + walkBeforeAfter(&n.Specs, before, after) + case *ast.FuncDecl: + if n.Recv != nil { + walkBeforeAfter(&n.Recv, before, after) + } + walkBeforeAfter(&n.Type, before, after) + if n.Body != nil { + walkBeforeAfter(&n.Body, before, after) + } + + case *ast.File: + walkBeforeAfter(&n.Decls, before, after) + + case *ast.Package: + walkBeforeAfter(&n.Files, before, after) + + case []*ast.File: + for i := range n { + walkBeforeAfter(&n[i], before, after) + } + case []ast.Decl: + for i := range n { + walkBeforeAfter(&n[i], before, after) + } + case []ast.Expr: + for i := range n { + walkBeforeAfter(&n[i], before, after) + } + case []ast.Stmt: + for i := range n { + walkBeforeAfter(&n[i], before, after) + } + case []ast.Spec: + for i := range n { + walkBeforeAfter(&n[i], before, after) + } + } + after(x) +} + +// imports returns true if f imports path. +func imports(f *ast.File, path string) bool { + for _, s := range f.Imports { + t, err := strconv.Unquote(s.Path.Value) + if err == nil && t == path { + return true + } + } + return false +} + +// isPkgDot returns true if t is the expression "pkg.name" +// where pkg is an imported identifier. +func isPkgDot(t ast.Expr, pkg, name string) bool { + sel, ok := t.(*ast.SelectorExpr) + return ok && isTopName(sel.X, pkg) && sel.Sel.String() == name +} + +// isPtrPkgDot returns true if f is the expression "*pkg.name" +// where pkg is an imported identifier. +func isPtrPkgDot(t ast.Expr, pkg, name string) bool { + ptr, ok := t.(*ast.StarExpr) + return ok && isPkgDot(ptr.X, pkg, name) +} + +// isTopName returns true if n is a top-level unresolved identifier with the given name. 
+func isTopName(n ast.Expr, name string) bool { + id, ok := n.(*ast.Ident) + return ok && id.Name == name && id.Obj == nil +} + +// isName returns true if n is an identifier with the given name. +func isName(n ast.Expr, name string) bool { + id, ok := n.(*ast.Ident) + return ok && id.String() == name +} + +// isCall returns true if t is a call to pkg.name. +func isCall(t ast.Expr, pkg, name string) bool { + call, ok := t.(*ast.CallExpr) + return ok && isPkgDot(call.Fun, pkg, name) +} + +// If n is an *ast.Ident, isIdent returns it; otherwise isIdent returns nil. +func isIdent(n interface{}) *ast.Ident { + id, _ := n.(*ast.Ident) + return id +} + +// refersTo returns true if n is a reference to the same object as x. +func refersTo(n ast.Node, x *ast.Ident) bool { + id, ok := n.(*ast.Ident) + // The test of id.Name == x.Name handles top-level unresolved + // identifiers, which all have Obj == nil. + return ok && id.Obj == x.Obj && id.Name == x.Name +} + +// isBlank returns true if n is the blank identifier. +func isBlank(n ast.Expr) bool { + return isName(n, "_") +} + +// isEmptyString returns true if n is an empty string literal. +func isEmptyString(n ast.Expr) bool { + lit, ok := n.(*ast.BasicLit) + return ok && lit.Kind == token.STRING && len(lit.Value) == 2 +} + +func warn(pos token.Pos, msg string, args ...interface{}) { + if pos.IsValid() { + msg = "%s: " + msg + arg1 := []interface{}{fset.Position(pos).String()} + args = append(arg1, args...) + } + fmt.Fprintf(os.Stderr, msg+"\n", args...) +} + +// countUses returns the number of uses of the identifier x in scope. +func countUses(x *ast.Ident, scope []ast.Stmt) int { + count := 0 + ff := func(n interface{}) { + if n, ok := n.(ast.Node); ok && refersTo(n, x) { + count++ + } + } + for _, n := range scope { + walk(n, ff) + } + return count +} + +// rewriteUses replaces all uses of the identifier x and !x in scope +// with f(x.Pos()) and fnot(x.Pos()). +func rewriteUses(x *ast.Ident, f, fnot func(token.Pos) ast.Expr, scope []ast.Stmt) { + var lastF ast.Expr + ff := func(n interface{}) { + ptr, ok := n.(*ast.Expr) + if !ok { + return + } + nn := *ptr + + // The child node was just walked and possibly replaced. + // If it was replaced and this is a negation, replace with fnot(p). + not, ok := nn.(*ast.UnaryExpr) + if ok && not.Op == token.NOT && not.X == lastF { + *ptr = fnot(nn.Pos()) + return + } + if refersTo(nn, x) { + lastF = f(nn.Pos()) + *ptr = lastF + } + } + for _, n := range scope { + walk(n, ff) + } +} + +// assignsTo returns true if any of the code in scope assigns to or takes the address of x. +func assignsTo(x *ast.Ident, scope []ast.Stmt) bool { + assigned := false + ff := func(n interface{}) { + if assigned { + return + } + switch n := n.(type) { + case *ast.UnaryExpr: + // use of &x + if n.Op == token.AND && refersTo(n.X, x) { + assigned = true + return + } + case *ast.AssignStmt: + for _, l := range n.Lhs { + if refersTo(l, x) { + assigned = true + return + } + } + } + } + for _, n := range scope { + if assigned { + break + } + walk(n, ff) + } + return assigned +} + +// newPkgDot returns an ast.Expr referring to "pkg.name" at position pos. 
+func newPkgDot(pos token.Pos, pkg, name string) ast.Expr { + return &ast.SelectorExpr{ + X: &ast.Ident{ + NamePos: pos, + Name: pkg, + }, + Sel: &ast.Ident{ + NamePos: pos, + Name: name, + }, + } +} diff --git a/src/cmd/gofix/httpserver.go b/src/cmd/gofix/httpserver.go new file mode 100644 index 000000000..37866e88b --- /dev/null +++ b/src/cmd/gofix/httpserver.go @@ -0,0 +1,140 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "go/ast" + "go/token" +) + +var httpserverFix = fix{ + "httpserver", + httpserver, + `Adapt http server methods and functions to changes +made to the http ResponseWriter interface. + +http://codereview.appspot.com/4245064 Hijacker +http://codereview.appspot.com/4239076 Header +http://codereview.appspot.com/4239077 Flusher +http://codereview.appspot.com/4248075 RemoteAddr, UsingTLS +`, +} + +func init() { + register(httpserverFix) +} + +func httpserver(f *ast.File) bool { + if !imports(f, "http") { + return false + } + + fixed := false + for _, decl := range f.Decls { + fn, ok := decl.(*ast.FuncDecl) + if !ok { + continue + } + w, req, ok := isServeHTTP(fn) + if !ok { + continue + } + walk(fn.Body, func(n interface{}) { + // Want to replace expression sometimes, + // so record pointer to it for updating below. + ptr, ok := n.(*ast.Expr) + if ok { + n = *ptr + } + + // Look for w.UsingTLS() and w.Remoteaddr(). + call, ok := n.(*ast.CallExpr) + if !ok || (len(call.Args) != 0 && len(call.Args) != 2) { + return + } + sel, ok := call.Fun.(*ast.SelectorExpr) + if !ok { + return + } + if !refersTo(sel.X, w) { + return + } + switch sel.Sel.String() { + case "Hijack": + // replace w with w.(http.Hijacker) + sel.X = &ast.TypeAssertExpr{ + X: sel.X, + Type: ast.NewIdent("http.Hijacker"), + } + fixed = true + case "Flush": + // replace w with w.(http.Flusher) + sel.X = &ast.TypeAssertExpr{ + X: sel.X, + Type: ast.NewIdent("http.Flusher"), + } + fixed = true + case "UsingTLS": + if ptr == nil { + // can only replace expression if we have pointer to it + break + } + // replace with req.TLS != nil + *ptr = &ast.BinaryExpr{ + X: &ast.SelectorExpr{ + X: ast.NewIdent(req.String()), + Sel: ast.NewIdent("TLS"), + }, + Op: token.NEQ, + Y: ast.NewIdent("nil"), + } + fixed = true + case "RemoteAddr": + if ptr == nil { + // can only replace expression if we have pointer to it + break + } + // replace with req.RemoteAddr + *ptr = &ast.SelectorExpr{ + X: ast.NewIdent(req.String()), + Sel: ast.NewIdent("RemoteAddr"), + } + fixed = true + case "SetHeader": + // replace w.SetHeader with w.Header().Set + // or w.Header().Del if second argument is "" + sel.X = &ast.CallExpr{ + Fun: &ast.SelectorExpr{ + X: ast.NewIdent(w.String()), + Sel: ast.NewIdent("Header"), + }, + } + sel.Sel = ast.NewIdent("Set") + if len(call.Args) == 2 && isEmptyString(call.Args[1]) { + sel.Sel = ast.NewIdent("Del") + call.Args = call.Args[:1] + } + fixed = true + } + }) + } + return fixed +} + +func isServeHTTP(fn *ast.FuncDecl) (w, req *ast.Ident, ok bool) { + for _, field := range fn.Type.Params.List { + if isPkgDot(field.Type, "http", "ResponseWriter") { + w = field.Names[0] + continue + } + if isPtrPkgDot(field.Type, "http", "Request") { + req = field.Names[0] + continue + } + } + + ok = w != nil && req != nil + return +} diff --git a/src/cmd/gofix/httpserver_test.go b/src/cmd/gofix/httpserver_test.go new file mode 100644 index 000000000..89bb4fa71 --- /dev/null +++ 
b/src/cmd/gofix/httpserver_test.go @@ -0,0 +1,53 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +func init() { + addTestCases(httpserverTests) +} + +var httpserverTests = []testCase{ + { + Name: "httpserver.0", + In: `package main + +import "http" + +func f(xyz http.ResponseWriter, abc *http.Request, b string) { + xyz.SetHeader("foo", "bar") + xyz.SetHeader("baz", "") + xyz.Hijack() + xyz.Flush() + go xyz.Hijack() + defer xyz.Flush() + _ = xyz.UsingTLS() + _ = true == xyz.UsingTLS() + _ = xyz.RemoteAddr() + _ = xyz.RemoteAddr() == "hello" + if xyz.UsingTLS() { + } +} +`, + Out: `package main + +import "http" + +func f(xyz http.ResponseWriter, abc *http.Request, b string) { + xyz.Header().Set("foo", "bar") + xyz.Header().Del("baz") + xyz.(http.Hijacker).Hijack() + xyz.(http.Flusher).Flush() + go xyz.(http.Hijacker).Hijack() + defer xyz.(http.Flusher).Flush() + _ = abc.TLS != nil + _ = true == (abc.TLS != nil) + _ = abc.RemoteAddr + _ = abc.RemoteAddr == "hello" + if abc.TLS != nil { + } +} +`, + }, +} diff --git a/src/cmd/gofix/main.go b/src/cmd/gofix/main.go new file mode 100644 index 000000000..4f7e923e3 --- /dev/null +++ b/src/cmd/gofix/main.go @@ -0,0 +1,264 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "bytes" + "exec" + "flag" + "fmt" + "go/parser" + "go/printer" + "go/scanner" + "go/token" + "io/ioutil" + "os" + "path/filepath" + "sort" + "strings" +) + +var ( + fset = token.NewFileSet() + exitCode = 0 +) + +var allowedRewrites = flag.String("r", "", + "restrict the rewrites to this comma-separated list") + +var allowed map[string]bool + +var doDiff = flag.Bool("diff", false, "display diffs instead of rewriting files") + +func usage() { + fmt.Fprintf(os.Stderr, "usage: gofix [-diff] [-r fixname,...] 
[path ...]\n") + flag.PrintDefaults() + fmt.Fprintf(os.Stderr, "\nAvailable rewrites are:\n") + for _, f := range fixes { + fmt.Fprintf(os.Stderr, "\n%s\n", f.name) + desc := strings.TrimSpace(f.desc) + desc = strings.Replace(desc, "\n", "\n\t", -1) + fmt.Fprintf(os.Stderr, "\t%s\n", desc) + } + os.Exit(2) +} + +func main() { + sort.Sort(fixes) + + flag.Usage = usage + flag.Parse() + + if *allowedRewrites != "" { + allowed = make(map[string]bool) + for _, f := range strings.Split(*allowedRewrites, ",", -1) { + allowed[f] = true + } + } + + if flag.NArg() == 0 { + if err := processFile("standard input", true); err != nil { + report(err) + } + os.Exit(exitCode) + } + + for i := 0; i < flag.NArg(); i++ { + path := flag.Arg(i) + switch dir, err := os.Stat(path); { + case err != nil: + report(err) + case dir.IsRegular(): + if err := processFile(path, false); err != nil { + report(err) + } + case dir.IsDirectory(): + walkDir(path) + } + } + + os.Exit(exitCode) +} + +const ( + tabWidth = 8 + parserMode = parser.ParseComments + printerMode = printer.TabIndent | printer.UseSpaces +) + +var printConfig = &printer.Config{ + printerMode, + tabWidth, +} + +func processFile(filename string, useStdin bool) os.Error { + var f *os.File + var err os.Error + var fixlog bytes.Buffer + var buf bytes.Buffer + + if useStdin { + f = os.Stdin + } else { + f, err = os.Open(filename) + if err != nil { + return err + } + defer f.Close() + } + + src, err := ioutil.ReadAll(f) + if err != nil { + return err + } + + file, err := parser.ParseFile(fset, filename, src, parserMode) + if err != nil { + return err + } + + // Apply all fixes to file. + newFile := file + fixed := false + for _, fix := range fixes { + if allowed != nil && !allowed[fix.desc] { + continue + } + if fix.f(newFile) { + fixed = true + fmt.Fprintf(&fixlog, " %s", fix.name) + + // AST changed. + // Print and parse, to update any missing scoping + // or position information for subsequent fixers. + buf.Reset() + _, err = printConfig.Fprint(&buf, fset, newFile) + if err != nil { + return err + } + newSrc := buf.Bytes() + newFile, err = parser.ParseFile(fset, filename, newSrc, parserMode) + if err != nil { + return err + } + } + } + if !fixed { + return nil + } + fmt.Fprintf(os.Stderr, "%s: fixed %s\n", filename, fixlog.String()[1:]) + + // Print AST. We did that after each fix, so this appears + // redundant, but it is necessary to generate gofmt-compatible + // source code in a few cases. The official gofmt style is the + // output of the printer run on a standard AST generated by the parser, + // but the source we generated inside the loop above is the + // output of the printer run on a mangled AST generated by a fixer. 
+ buf.Reset() + _, err = printConfig.Fprint(&buf, fset, newFile) + if err != nil { + return err + } + newSrc := buf.Bytes() + + if *doDiff { + data, err := diff(src, newSrc) + if err != nil { + return fmt.Errorf("computing diff: %s", err) + } + fmt.Printf("diff %s fixed/%s\n", filename, filename) + os.Stdout.Write(data) + return nil + } + + if useStdin { + os.Stdout.Write(newSrc) + return nil + } + + return ioutil.WriteFile(f.Name(), newSrc, 0) +} + +var gofmtBuf bytes.Buffer + +func gofmt(n interface{}) string { + gofmtBuf.Reset() + _, err := printConfig.Fprint(&gofmtBuf, fset, n) + if err != nil { + return "<" + err.String() + ">" + } + return gofmtBuf.String() +} + +func report(err os.Error) { + scanner.PrintError(os.Stderr, err) + exitCode = 2 +} + +func walkDir(path string) { + v := make(fileVisitor) + go func() { + filepath.Walk(path, v, v) + close(v) + }() + for err := range v { + if err != nil { + report(err) + } + } +} + +type fileVisitor chan os.Error + +func (v fileVisitor) VisitDir(path string, f *os.FileInfo) bool { + return true +} + +func (v fileVisitor) VisitFile(path string, f *os.FileInfo) { + if isGoFile(f) { + v <- nil // synchronize error handler + if err := processFile(path, false); err != nil { + v <- err + } + } +} + +func isGoFile(f *os.FileInfo) bool { + // ignore non-Go files + return f.IsRegular() && !strings.HasPrefix(f.Name, ".") && strings.HasSuffix(f.Name, ".go") +} + +func diff(b1, b2 []byte) (data []byte, err os.Error) { + f1, err := ioutil.TempFile("", "gofix") + if err != nil { + return nil, err + } + defer os.Remove(f1.Name()) + defer f1.Close() + + f2, err := ioutil.TempFile("", "gofix") + if err != nil { + return nil, err + } + defer os.Remove(f2.Name()) + defer f2.Close() + + f1.Write(b1) + f2.Write(b2) + + diffcmd, err := exec.LookPath("diff") + if err != nil { + return nil, err + } + + c, err := exec.Run(diffcmd, []string{"diff", f1.Name(), f2.Name()}, nil, "", + exec.DevNull, exec.Pipe, exec.MergeWithStdout) + if err != nil { + return nil, err + } + defer c.Close() + + return ioutil.ReadAll(c.Stdout) +} diff --git a/src/cmd/gofix/main_test.go b/src/cmd/gofix/main_test.go new file mode 100644 index 000000000..275778e5b --- /dev/null +++ b/src/cmd/gofix/main_test.go @@ -0,0 +1,126 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "bytes" + "go/ast" + "go/parser" + "go/printer" + "strings" + "testing" +) + +type testCase struct { + Name string + Fn func(*ast.File) bool + In string + Out string +} + +var testCases []testCase + +func addTestCases(t []testCase) { + testCases = append(testCases, t...) 
+} + +func fnop(*ast.File) bool { return false } + +func parseFixPrint(t *testing.T, fn func(*ast.File) bool, desc, in string) (out string, fixed, ok bool) { + file, err := parser.ParseFile(fset, desc, in, parserMode) + if err != nil { + t.Errorf("%s: parsing: %v", desc, err) + return + } + + var buf bytes.Buffer + buf.Reset() + _, err = (&printer.Config{printerMode, tabWidth}).Fprint(&buf, fset, file) + if err != nil { + t.Errorf("%s: printing: %v", desc, err) + return + } + if s := buf.String(); in != s && fn != fnop { + t.Errorf("%s: not gofmt-formatted.\n--- %s\n%s\n--- %s | gofmt\n%s", + desc, desc, in, desc, s) + tdiff(t, in, s) + return + } + + if fn == nil { + for _, fix := range fixes { + if fix.f(file) { + fixed = true + } + } + } else { + fixed = fn(file) + } + + buf.Reset() + _, err = (&printer.Config{printerMode, tabWidth}).Fprint(&buf, fset, file) + if err != nil { + t.Errorf("%s: printing: %v", desc, err) + return + } + + return buf.String(), fixed, true +} + +func TestRewrite(t *testing.T) { + for _, tt := range testCases { + // Apply fix: should get tt.Out. + out, fixed, ok := parseFixPrint(t, tt.Fn, tt.Name, tt.In) + if !ok { + continue + } + + // reformat to get printing right + out, _, ok = parseFixPrint(t, fnop, tt.Name, out) + if !ok { + continue + } + + if out != tt.Out { + t.Errorf("%s: incorrect output.\n", tt.Name) + if !strings.HasPrefix(tt.Name, "testdata/") { + t.Errorf("--- have\n%s\n--- want\n%s", out, tt.Out) + } + tdiff(t, out, tt.Out) + continue + } + + if changed := out != tt.In; changed != fixed { + t.Errorf("%s: changed=%v != fixed=%v", tt.Name, changed, fixed) + continue + } + + // Should not change if run again. + out2, fixed2, ok := parseFixPrint(t, tt.Fn, tt.Name+" output", out) + if !ok { + continue + } + + if fixed2 { + t.Errorf("%s: applied fixes during second round", tt.Name) + continue + } + + if out2 != out { + t.Errorf("%s: changed output after second round of fixes.\n--- output after first round\n%s\n--- output after second round\n%s", + tt.Name, out, out2) + tdiff(t, out, out2) + } + } +} + +func tdiff(t *testing.T, a, b string) { + data, err := diff([]byte(a), []byte(b)) + if err != nil { + t.Error(err) + return + } + t.Error(string(data)) +} diff --git a/src/cmd/gofix/netdial.go b/src/cmd/gofix/netdial.go new file mode 100644 index 000000000..afa98953b --- /dev/null +++ b/src/cmd/gofix/netdial.go @@ -0,0 +1,114 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "go/ast" +) + +var netdialFix = fix{ + "netdial", + netdial, + `Adapt 3-argument calls of net.Dial to use 2-argument form. + +http://codereview.appspot.com/4244055 +`, +} + +var tlsdialFix = fix{ + "tlsdial", + tlsdial, + `Adapt 4-argument calls of tls.Dial to use 3-argument form. + +http://codereview.appspot.com/4244055 +`, +} + +var netlookupFix = fix{ + "netlookup", + netlookup, + `Adapt 3-result calls to net.LookupHost to use 2-result form. 
+ +http://codereview.appspot.com/4244055 +`, +} + +func init() { + register(netdialFix) + register(tlsdialFix) + register(netlookupFix) +} + +func netdial(f *ast.File) bool { + if !imports(f, "net") { + return false + } + + fixed := false + walk(f, func(n interface{}) { + call, ok := n.(*ast.CallExpr) + if !ok || !isPkgDot(call.Fun, "net", "Dial") || len(call.Args) != 3 { + return + } + // net.Dial(a, "", b) -> net.Dial(a, b) + if !isEmptyString(call.Args[1]) { + warn(call.Pos(), "call to net.Dial with non-empty second argument") + return + } + call.Args[1] = call.Args[2] + call.Args = call.Args[:2] + fixed = true + }) + return fixed +} + +func tlsdial(f *ast.File) bool { + if !imports(f, "crypto/tls") { + return false + } + + fixed := false + walk(f, func(n interface{}) { + call, ok := n.(*ast.CallExpr) + if !ok || !isPkgDot(call.Fun, "tls", "Dial") || len(call.Args) != 4 { + return + } + // tls.Dial(a, "", b, c) -> tls.Dial(a, b, c) + if !isEmptyString(call.Args[1]) { + warn(call.Pos(), "call to tls.Dial with non-empty second argument") + return + } + call.Args[1] = call.Args[2] + call.Args[2] = call.Args[3] + call.Args = call.Args[:3] + fixed = true + }) + return fixed +} + +func netlookup(f *ast.File) bool { + if !imports(f, "net") { + return false + } + + fixed := false + walk(f, func(n interface{}) { + as, ok := n.(*ast.AssignStmt) + if !ok || len(as.Lhs) != 3 || len(as.Rhs) != 1 { + return + } + call, ok := as.Rhs[0].(*ast.CallExpr) + if !ok || !isPkgDot(call.Fun, "net", "LookupHost") { + return + } + if !isBlank(as.Lhs[2]) { + warn(as.Pos(), "call to net.LookupHost expecting cname; use net.LookupCNAME") + return + } + as.Lhs = as.Lhs[:2] + fixed = true + }) + return fixed +} diff --git a/src/cmd/gofix/netdial_test.go b/src/cmd/gofix/netdial_test.go new file mode 100644 index 000000000..272aa526a --- /dev/null +++ b/src/cmd/gofix/netdial_test.go @@ -0,0 +1,51 @@ +package main + +func init() { + addTestCases(netdialTests) +} + +var netdialTests = []testCase{ + { + Name: "netdial.0", + In: `package main + +import "net" + +func f() { + c, err := net.Dial(net, "", addr) + c, err = net.Dial(net, "", addr) +} +`, + Out: `package main + +import "net" + +func f() { + c, err := net.Dial(net, addr) + c, err = net.Dial(net, addr) +} +`, + }, + + { + Name: "netlookup.0", + In: `package main + +import "net" + +func f() { + foo, bar, _ := net.LookupHost(host) + foo, bar, _ = net.LookupHost(host) +} +`, + Out: `package main + +import "net" + +func f() { + foo, bar := net.LookupHost(host) + foo, bar = net.LookupHost(host) +} +`, + }, +} diff --git a/src/cmd/gofix/osopen.go b/src/cmd/gofix/osopen.go new file mode 100644 index 000000000..8eb5d0655 --- /dev/null +++ b/src/cmd/gofix/osopen.go @@ -0,0 +1,122 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "go/ast" +) + +var osopenFix = fix{ + "osopen", + osopen, + `Adapt os.Open calls to new, easier API and rename O_CREAT O_CREATE. + + http://codereview.appspot.com/4357052 +`, +} + +func init() { + register(osopenFix) +} + +func osopen(f *ast.File) bool { + if !imports(f, "os") { + return false + } + + fixed := false + walk(f, func(n interface{}) { + // Rename O_CREAT to O_CREATE. + if expr, ok := n.(ast.Expr); ok && isPkgDot(expr, "os", "O_CREAT") { + expr.(*ast.SelectorExpr).Sel.Name = "O_CREATE" + return + } + + // Fix up calls to Open. 
+ call, ok := n.(*ast.CallExpr) + if !ok || len(call.Args) != 3 { + return + } + if !isPkgDot(call.Fun, "os", "Open") { + return + } + sel := call.Fun.(*ast.SelectorExpr) + args := call.Args + // os.Open(a, os.O_RDONLY, c) -> os.Open(a) + if isPkgDot(args[1], "os", "O_RDONLY") || isPkgDot(args[1], "syscall", "O_RDONLY") { + call.Args = call.Args[0:1] + fixed = true + return + } + // os.Open(a, createlike_flags, c) -> os.Create(a, c) + if isCreateFlag(args[1]) { + sel.Sel.Name = "Create" + if !isSimplePerm(args[2]) { + warn(sel.Pos(), "rewrote os.Open to os.Create with permission not 0666") + } + call.Args = args[0:1] + fixed = true + return + } + // Fallback: os.Open(a, b, c) -> os.OpenFile(a, b, c) + sel.Sel.Name = "OpenFile" + fixed = true + }) + return fixed +} + +func isCreateFlag(flag ast.Expr) bool { + foundCreate := false + foundTrunc := false + // OR'ing of flags: is O_CREATE on? + or | would be fine; we just look for os.O_CREATE + // and don't worry about the actual operator. + p := flag.Pos() + for { + lhs := flag + expr, isBinary := flag.(*ast.BinaryExpr) + if isBinary { + lhs = expr.Y + } + sel, ok := lhs.(*ast.SelectorExpr) + if !ok || !isTopName(sel.X, "os") { + return false + } + switch sel.Sel.Name { + case "O_CREATE": + foundCreate = true + case "O_TRUNC": + foundTrunc = true + case "O_RDONLY", "O_WRONLY", "O_RDWR": + // okay + default: + // Unexpected flag, like O_APPEND or O_EXCL. + // Be conservative and do not rewrite. + return false + } + if !isBinary { + break + } + flag = expr.X + } + if !foundCreate { + return false + } + if !foundTrunc { + warn(p, "rewrote os.Open with O_CREATE but not O_TRUNC to os.Create") + } + return foundCreate +} + +func isSimplePerm(perm ast.Expr) bool { + basicLit, ok := perm.(*ast.BasicLit) + if !ok { + return false + } + switch basicLit.Value { + case "0666": + return true + } + return false +} diff --git a/src/cmd/gofix/osopen_test.go b/src/cmd/gofix/osopen_test.go new file mode 100644 index 000000000..43ddd1a40 --- /dev/null +++ b/src/cmd/gofix/osopen_test.go @@ -0,0 +1,59 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +func init() { + addTestCases(osopenTests) +} + +var osopenTests = []testCase{ + { + Name: "osopen.0", + In: `package main + +import ( + "os" +) + +func f() { + os.OpenFile(a, b, c) + os.Open(a, os.O_RDONLY, 0) + os.Open(a, os.O_RDONLY, 0666) + os.Open(a, os.O_RDWR, 0) + os.Open(a, os.O_CREAT, 0666) + os.Open(a, os.O_CREAT|os.O_TRUNC, 0664) + os.Open(a, os.O_CREATE, 0666) + os.Open(a, os.O_CREATE|os.O_TRUNC, 0664) + os.Open(a, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0666) + os.Open(a, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0666) + os.Open(a, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0666) + os.Open(a, os.O_SURPRISE|os.O_CREATE, 0666) + _ = os.O_CREAT +} +`, + Out: `package main + +import ( + "os" +) + +func f() { + os.OpenFile(a, b, c) + os.Open(a) + os.Open(a) + os.OpenFile(a, os.O_RDWR, 0) + os.Create(a) + os.Create(a) + os.Create(a) + os.Create(a) + os.Create(a) + os.OpenFile(a, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0666) + os.OpenFile(a, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0666) + os.OpenFile(a, os.O_SURPRISE|os.O_CREATE, 0666) + _ = os.O_CREATE +} +`, + }, +} diff --git a/src/cmd/gofix/procattr.go b/src/cmd/gofix/procattr.go new file mode 100644 index 000000000..0e2190b1f --- /dev/null +++ b/src/cmd/gofix/procattr.go @@ -0,0 +1,61 @@ +// Copyright 2011 The Go Authors. All rights reserved. 
+// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "go/ast" + "go/token" +) + +var procattrFix = fix{ + "procattr", + procattr, + `Adapt calls to os.StartProcess to use new ProcAttr type. + +http://codereview.appspot.com/4253052 +`, +} + +func init() { + register(procattrFix) +} + +func procattr(f *ast.File) bool { + if !imports(f, "os") && !imports(f, "syscall") { + return false + } + + fixed := false + walk(f, func(n interface{}) { + call, ok := n.(*ast.CallExpr) + if !ok || len(call.Args) != 5 { + return + } + var pkg string + if isPkgDot(call.Fun, "os", "StartProcess") { + pkg = "os" + } else if isPkgDot(call.Fun, "syscall", "StartProcess") { + pkg = "syscall" + } else { + return + } + // os.StartProcess(a, b, c, d, e) -> os.StartProcess(a, b, &os.ProcAttr{Env: c, Dir: d, Files: e}) + lit := &ast.CompositeLit{Type: ast.NewIdent(pkg + ".ProcAttr")} + env, dir, files := call.Args[2], call.Args[3], call.Args[4] + if !isName(env, "nil") && !isCall(env, "os", "Environ") { + lit.Elts = append(lit.Elts, &ast.KeyValueExpr{Key: ast.NewIdent("Env"), Value: env}) + } + if !isEmptyString(dir) { + lit.Elts = append(lit.Elts, &ast.KeyValueExpr{Key: ast.NewIdent("Dir"), Value: dir}) + } + if !isName(files, "nil") { + lit.Elts = append(lit.Elts, &ast.KeyValueExpr{Key: ast.NewIdent("Files"), Value: files}) + } + call.Args[2] = &ast.UnaryExpr{Op: token.AND, X: lit} + call.Args = call.Args[:3] + fixed = true + }) + return fixed +} diff --git a/src/cmd/gofix/procattr_test.go b/src/cmd/gofix/procattr_test.go new file mode 100644 index 000000000..b973b9684 --- /dev/null +++ b/src/cmd/gofix/procattr_test.go @@ -0,0 +1,74 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package main + +func init() { + addTestCases(procattrTests) +} + +var procattrTests = []testCase{ + { + Name: "procattr.0", + In: `package main + +import ( + "os" + "syscall" +) + +func f() { + os.StartProcess(a, b, c, d, e) + os.StartProcess(a, b, os.Environ(), d, e) + os.StartProcess(a, b, nil, d, e) + os.StartProcess(a, b, c, "", e) + os.StartProcess(a, b, c, d, nil) + os.StartProcess(a, b, nil, "", nil) + + os.StartProcess( + a, + b, + c, + d, + e, + ) + + syscall.StartProcess(a, b, c, d, e) + syscall.StartProcess(a, b, os.Environ(), d, e) + syscall.StartProcess(a, b, nil, d, e) + syscall.StartProcess(a, b, c, "", e) + syscall.StartProcess(a, b, c, d, nil) + syscall.StartProcess(a, b, nil, "", nil) +} +`, + Out: `package main + +import ( + "os" + "syscall" +) + +func f() { + os.StartProcess(a, b, &os.ProcAttr{Env: c, Dir: d, Files: e}) + os.StartProcess(a, b, &os.ProcAttr{Dir: d, Files: e}) + os.StartProcess(a, b, &os.ProcAttr{Dir: d, Files: e}) + os.StartProcess(a, b, &os.ProcAttr{Env: c, Files: e}) + os.StartProcess(a, b, &os.ProcAttr{Env: c, Dir: d}) + os.StartProcess(a, b, &os.ProcAttr{}) + + os.StartProcess( + a, + b, &os.ProcAttr{Env: c, Dir: d, Files: e}, + ) + + syscall.StartProcess(a, b, &syscall.ProcAttr{Env: c, Dir: d, Files: e}) + syscall.StartProcess(a, b, &syscall.ProcAttr{Dir: d, Files: e}) + syscall.StartProcess(a, b, &syscall.ProcAttr{Dir: d, Files: e}) + syscall.StartProcess(a, b, &syscall.ProcAttr{Env: c, Files: e}) + syscall.StartProcess(a, b, &syscall.ProcAttr{Env: c, Dir: d}) + syscall.StartProcess(a, b, &syscall.ProcAttr{}) +} +`, + }, +} diff --git a/src/cmd/gofix/reflect.go b/src/cmd/gofix/reflect.go new file mode 100644 index 000000000..74ddb398f --- /dev/null +++ b/src/cmd/gofix/reflect.go @@ -0,0 +1,843 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// TODO(rsc): Once there is better support for writing +// multi-package commands, this should really be in +// its own package, and then we can drop all the "reflect" +// prefixes on the global variables and functions. + +package main + +import ( + "go/ast" + "go/token" + "strings" +) + +var reflectFix = fix{ + "reflect", + reflectFn, + `Adapt code to new reflect API. + +http://codereview.appspot.com/4281055 +`, +} + +func init() { + register(reflectFix) +} + +// The reflect API change dropped the concrete types *reflect.ArrayType etc. +// Any type assertions prior to method calls can be deleted: +// x.(*reflect.ArrayType).Len() -> x.Len() +// +// Any type checks can be replaced by assignment and check of Kind: +// x, y := z.(*reflect.ArrayType) +// -> +// x := z +// y := x.Kind() == reflect.Array +// +// If z is an ordinary variable name and x is not subsequently assigned to, +// references to x can be replaced by z and the assignment deleted. +// We only bother if x and z are the same name. +// If y is not subsequently assigned to and neither is x, references to +// y can be replaced by its expression. We only bother when there is +// just one use or when the use appears in an if clause. +// +// Not all type checks result in a single Kind check. The rewrite of the type check for +// reflect.ArrayOrSliceType checks x.Kind() against reflect.Array and reflect.Slice. +// The rewrite for *reflect.IntType checks againt Int, Int8, Int16, Int32, Int64. +// The rewrite for *reflect.UintType adds Uintptr. 
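// As a concrete sketch of the multi-kind case (using the same x, y, z names as above):
//	x, y := z.(*reflect.IntType)
// becomes
//	x := z
//	y := x.Kind() == reflect.Int || x.Kind() == reflect.Int8 ||
//		x.Kind() == reflect.Int16 || x.Kind() == reflect.Int32 ||
//		x.Kind() == reflect.Int64
// subject to the same simplifications described above when x or y is unused.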
+// +// A type switch turns into an assignment and a switch on Kind: +// switch x := y.(type) { +// case reflect.ArrayOrSliceType: +// ... +// case *reflect.ChanType: +// ... +// case *reflect.IntType: +// ... +// } +// -> +// switch x := y; x.Kind() { +// case reflect.Array, reflect.Slice: +// ... +// case reflect.Chan: +// ... +// case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: +// ... +// } +// +// The same simplification applies: we drop x := x if x is not assigned +// to in the switch cases. +// +// Because the type check assignment includes a type assertion in its +// syntax and the rewrite traversal is bottom up, we must do a pass to +// rewrite the type check assignments and then a separate pass to +// rewrite the type assertions. +// +// The same process applies to the API changes for reflect.Value. +// +// For both cases, but especially Value, the code needs to be aware +// of the type of a receiver when rewriting a method call. For example, +// x.(*reflect.ArrayValue).Elem(i) becomes x.Index(i) while +// x.(*reflect.MapValue).Elem(v) becomes x.MapIndex(v). +// In general, reflectFn needs to know the type of the receiver expression. +// In most cases (and in all the cases in the Go source tree), the toy +// type checker in typecheck.go provides enough information for gofix +// to make the rewrite. If gofix misses a rewrite, the code that is left over +// will not compile, so it will be noticed immediately. + +func reflectFn(f *ast.File) bool { + if !imports(f, "reflect") { + return false + } + + fixed := false + + // Rewrite names in method calls. + // Needs basic type information (see above). + typeof := typecheck(reflectTypeConfig, f) + walk(f, func(n interface{}) { + switch n := n.(type) { + case *ast.SelectorExpr: + typ := typeof[n.X] + if m := reflectRewriteMethod[typ]; m != nil { + if replace := m[n.Sel.Name]; replace != "" { + n.Sel.Name = replace + fixed = true + return + } + } + + // For all reflect Values, replace SetValue with Set. + if isReflectValue[typ] && n.Sel.Name == "SetValue" { + n.Sel.Name = "Set" + fixed = true + return + } + + // Replace reflect.MakeZero with reflect.Zero. + if isPkgDot(n, "reflect", "MakeZero") { + n.Sel.Name = "Zero" + fixed = true + return + } + } + }) + + // Replace PtrValue's PointTo(x) with Set(x.Addr()). + walk(f, func(n interface{}) { + call, ok := n.(*ast.CallExpr) + if !ok || len(call.Args) != 1 { + return + } + sel, ok := call.Fun.(*ast.SelectorExpr) + if !ok || sel.Sel.Name != "PointTo" { + return + } + typ := typeof[sel.X] + if typ != "*reflect.PtrValue" { + return + } + sel.Sel.Name = "Set" + if !isTopName(call.Args[0], "nil") { + call.Args[0] = &ast.SelectorExpr{ + X: call.Args[0], + Sel: ast.NewIdent("Addr()"), + } + } + fixed = true + }) + + // Fix type switches. + walk(f, func(n interface{}) { + if reflectFixSwitch(n) { + fixed = true + } + }) + + // Fix type assertion checks (multiple assignment statements). + // Have to work on the statement context (statement list or if statement) + // so that we can insert an extra statement occasionally. + // Ignoring for and switch because they don't come up in + // typical code. + walk(f, func(n interface{}) { + switch n := n.(type) { + case *[]ast.Stmt: + // v is the replacement statement list. + var v []ast.Stmt + insert := func(x ast.Stmt) { + v = append(v, x) + } + for i, x := range *n { + // Tentatively append to v; if we rewrite x + // we'll have to update the entry, so remember + // the index. 
+ j := len(v) + v = append(v, x) + if reflectFixTypecheck(&x, insert, (*n)[i+1:]) { + // reflectFixTypecheck may have overwritten x. + // Update the entry we appended just before the call. + v[j] = x + fixed = true + } + } + *n = v + case *ast.IfStmt: + x := &ast.ExprStmt{n.Cond} + if reflectFixTypecheck(&n.Init, nil, []ast.Stmt{x, n.Body, n.Else}) { + n.Cond = x.X + fixed = true + } + } + }) + + // Warn about any typecheck statements that we missed. + walk(f, reflectWarnTypecheckStmt) + + // Now that those are gone, fix remaining type assertions. + // Delayed because the type checks have + // type assertions as part of their syntax. + walk(f, func(n interface{}) { + if reflectFixAssert(n) { + fixed = true + } + }) + + // Now that the type assertions are gone, rewrite remaining + // references to specific reflect types to use the general ones. + walk(f, func(n interface{}) { + ptr, ok := n.(*ast.Expr) + if !ok { + return + } + nn := *ptr + typ := reflectType(nn) + if typ == "" { + return + } + if strings.HasSuffix(typ, "Type") { + *ptr = newPkgDot(nn.Pos(), "reflect", "Type") + } else { + *ptr = newPkgDot(nn.Pos(), "reflect", "Value") + } + fixed = true + }) + + // Rewrite v.Set(nil) to v.Set(reflect.MakeZero(v.Type())). + walk(f, func(n interface{}) { + call, ok := n.(*ast.CallExpr) + if !ok || len(call.Args) != 1 || !isTopName(call.Args[0], "nil") { + return + } + sel, ok := call.Fun.(*ast.SelectorExpr) + if !ok || !isReflectValue[typeof[sel.X]] || sel.Sel.Name != "Set" { + return + } + call.Args[0] = &ast.CallExpr{ + Fun: newPkgDot(call.Args[0].Pos(), "reflect", "Zero"), + Args: []ast.Expr{ + &ast.CallExpr{ + Fun: &ast.SelectorExpr{ + X: sel.X, + Sel: &ast.Ident{Name: "Type"}, + }, + }, + }, + } + fixed = true + }) + + // Rewrite v != nil to v.IsValid(). + // Rewrite nil used as reflect.Value (in function argument or return) to reflect.Value{}. + walk(f, func(n interface{}) { + ptr, ok := n.(*ast.Expr) + if !ok { + return + } + if isTopName(*ptr, "nil") && isReflectValue[typeof[*ptr]] { + *ptr = ast.NewIdent("reflect.Value{}") + fixed = true + return + } + nn, ok := (*ptr).(*ast.BinaryExpr) + if !ok || (nn.Op != token.EQL && nn.Op != token.NEQ) || !isTopName(nn.Y, "nil") || !isReflectValue[typeof[nn.X]] { + return + } + var call ast.Expr = &ast.CallExpr{ + Fun: &ast.SelectorExpr{ + X: nn.X, + Sel: &ast.Ident{Name: "IsValid"}, + }, + } + if nn.Op == token.EQL { + call = &ast.UnaryExpr{Op: token.NOT, X: call} + } + *ptr = call + fixed = true + }) + + return fixed +} + +// reflectFixSwitch rewrites *n (if n is an *ast.Stmt) corresponding +// to a type switch. +func reflectFixSwitch(n interface{}) bool { + ptr, ok := n.(*ast.Stmt) + if !ok { + return false + } + n = *ptr + + ts, ok := n.(*ast.TypeSwitchStmt) + if !ok { + return false + } + + // Are any switch cases referring to reflect types? + // (That is, is this an old reflect type switch?) + for _, cas := range ts.Body.List { + for _, typ := range cas.(*ast.CaseClause).List { + if reflectType(typ) != "" { + goto haveReflect + } + } + } + return false + +haveReflect: + // Now we know it's an old reflect type switch. Prepare the new version, + // but don't replace or edit the original until we're sure of success. + + // Figure out the initializer statement, if any, and the receiver for the Kind call. 
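// For orientation, the two switch shapes handled below are, roughly:
//	switch x := y.(type) { ... }   // *ast.AssignStmt: x is dropped when it is blank or
//	                               // shares y's name and is never reassigned in the cases
//	switch y.(type) { ... }        // *ast.ExprStmt: the receiver is simply y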
+ var init ast.Stmt + var rcvr ast.Expr + + init = ts.Init + switch n := ts.Assign.(type) { + default: + warn(ts.Pos(), "unexpected form in type switch") + return false + + case *ast.AssignStmt: + as := n + ta := as.Rhs[0].(*ast.TypeAssertExpr) + x := isIdent(as.Lhs[0]) + z := isIdent(ta.X) + + if isBlank(x) || x != nil && z != nil && x.Name == z.Name && !assignsTo(x, ts.Body.List) { + // Can drop the variable creation. + rcvr = ta.X + } else { + // Need to use initialization statement. + if init != nil { + warn(ts.Pos(), "cannot rewrite reflect type switch with initializing statement") + return false + } + init = &ast.AssignStmt{ + Lhs: []ast.Expr{as.Lhs[0]}, + TokPos: as.TokPos, + Tok: token.DEFINE, + Rhs: []ast.Expr{ta.X}, + } + rcvr = as.Lhs[0] + } + + case *ast.ExprStmt: + rcvr = n.X.(*ast.TypeAssertExpr).X + } + + // Prepare rewritten type switch (see large comment above for form). + sw := &ast.SwitchStmt{ + Switch: ts.Switch, + Init: init, + Tag: &ast.CallExpr{ + Fun: &ast.SelectorExpr{ + X: rcvr, + Sel: &ast.Ident{ + NamePos: rcvr.End(), + Name: "Kind", + Obj: nil, + }, + }, + Lparen: rcvr.End(), + Rparen: rcvr.End(), + }, + Body: &ast.BlockStmt{ + Lbrace: ts.Body.Lbrace, + List: nil, // to be filled in + Rbrace: ts.Body.Rbrace, + }, + } + + // Translate cases. + for _, tcas := range ts.Body.List { + tcas := tcas.(*ast.CaseClause) + cas := &ast.CaseClause{ + Case: tcas.Case, + Colon: tcas.Colon, + Body: tcas.Body, + } + for _, t := range tcas.List { + if isTopName(t, "nil") { + cas.List = append(cas.List, newPkgDot(t.Pos(), "reflect", "Invalid")) + continue + } + + typ := reflectType(t) + if typ == "" { + warn(t.Pos(), "cannot rewrite reflect type switch case with non-reflect type %s", gofmt(t)) + cas.List = append(cas.List, t) + continue + } + + for _, k := range reflectKind[typ] { + cas.List = append(cas.List, newPkgDot(t.Pos(), "reflect", k)) + } + } + sw.Body.List = append(sw.Body.List, cas) + } + + // Everything worked. Rewrite AST. + *ptr = sw + return true +} + +// Rewrite x, y = z.(T) into +// x = z +// y = x.Kind() == K +// as described in the long comment above. +// +// If insert != nil, it can be called to insert a statement after *ptr in its block. +// If insert == nil, insertion is not possible. +// At most one call to insert is allowed. +// +// Scope gives the statements for which a declaration +// in *ptr would be in scope. +// +// The result is true of the statement was rewritten. +// +func reflectFixTypecheck(ptr *ast.Stmt, insert func(ast.Stmt), scope []ast.Stmt) bool { + st := *ptr + as, ok := st.(*ast.AssignStmt) + if !ok || len(as.Lhs) != 2 || len(as.Rhs) != 1 { + return false + } + + ta, ok := as.Rhs[0].(*ast.TypeAssertExpr) + if !ok { + return false + } + typ := reflectType(ta.Type) + if typ == "" { + return false + } + + // Have x, y := z.(t). + x := isIdent(as.Lhs[0]) + y := isIdent(as.Lhs[1]) + z := isIdent(ta.X) + + // First step is x := z, unless it's x := x and the resulting x is never reassigned. + // rcvr is the x in x.Kind(). + var rcvr ast.Expr + if isBlank(x) || + as.Tok == token.DEFINE && x != nil && z != nil && x.Name == z.Name && !assignsTo(x, scope) { + // Can drop the statement. + // If we need to insert a statement later, now we have a slot. + *ptr = &ast.EmptyStmt{} + insert = func(x ast.Stmt) { *ptr = x } + rcvr = ta.X + } else { + *ptr = &ast.AssignStmt{ + Lhs: []ast.Expr{as.Lhs[0]}, + TokPos: as.TokPos, + Tok: as.Tok, + Rhs: []ast.Expr{ta.X}, + } + rcvr = as.Lhs[0] + } + + // Prepare x.Kind() == T expression appropriate to t. 
+ // If x is not a simple identifier, warn that we might be + // reevaluating x. + if x == nil { + warn(as.Pos(), "rewrite reevaluates expr with possible side effects: %s", gofmt(as.Lhs[0])) + } + yExpr, yNotExpr := reflectKindEq(rcvr, reflectKind[typ]) + + // Second step is y := x.Kind() == T, unless it's only used once + // or we have no way to insert that statement. + var yStmt *ast.AssignStmt + if as.Tok == token.DEFINE && countUses(y, scope) <= 1 || insert == nil { + // Can drop the statement and use the expression directly. + rewriteUses(y, + func(token.Pos) ast.Expr { return yExpr }, + func(token.Pos) ast.Expr { return yNotExpr }, + scope) + } else { + yStmt = &ast.AssignStmt{ + Lhs: []ast.Expr{as.Lhs[1]}, + TokPos: as.End(), + Tok: as.Tok, + Rhs: []ast.Expr{yExpr}, + } + insert(yStmt) + } + return true +} + +// reflectKindEq returns the expression z.Kind() == kinds[0] || z.Kind() == kinds[1] || ... +// and its negation. +// The qualifier "reflect." is inserted before each kinds[i] expression. +func reflectKindEq(z ast.Expr, kinds []string) (ast.Expr, ast.Expr) { + n := len(kinds) + if n == 1 { + y := &ast.BinaryExpr{ + X: &ast.CallExpr{ + Fun: &ast.SelectorExpr{ + X: z, + Sel: ast.NewIdent("Kind"), + }, + }, + Op: token.EQL, + Y: newPkgDot(token.NoPos, "reflect", kinds[0]), + } + ynot := &ast.BinaryExpr{ + X: &ast.CallExpr{ + Fun: &ast.SelectorExpr{ + X: z, + Sel: ast.NewIdent("Kind"), + }, + }, + Op: token.NEQ, + Y: newPkgDot(token.NoPos, "reflect", kinds[0]), + } + return y, ynot + } + + x, xnot := reflectKindEq(z, kinds[0:n-1]) + y, ynot := reflectKindEq(z, kinds[n-1:]) + + or := &ast.BinaryExpr{ + X: x, + Op: token.LOR, + Y: y, + } + andnot := &ast.BinaryExpr{ + X: xnot, + Op: token.LAND, + Y: ynot, + } + return or, andnot +} + +// if x represents a known old reflect type/value like *reflect.PtrType or reflect.ArrayOrSliceValue, +// reflectType returns the string form of that type. +func reflectType(x ast.Expr) string { + ptr, ok := x.(*ast.StarExpr) + if ok { + x = ptr.X + } + + sel, ok := x.(*ast.SelectorExpr) + if !ok || !isName(sel.X, "reflect") { + return "" + } + + var s = "reflect." + if ptr != nil { + s = "*reflect." + } + s += sel.Sel.Name + + if reflectKind[s] != nil { + return s + } + return "" +} + +// reflectWarnTypecheckStmt warns about statements +// of the form x, y = z.(T) for any old reflect type T. +// The last pass should have gotten them all, and if it didn't, +// the next pass is going to turn them into x, y = z. +func reflectWarnTypecheckStmt(n interface{}) { + as, ok := n.(*ast.AssignStmt) + if !ok || len(as.Lhs) != 2 || len(as.Rhs) != 1 { + return + } + ta, ok := as.Rhs[0].(*ast.TypeAssertExpr) + if !ok || reflectType(ta.Type) == "" { + return + } + warn(n.(ast.Node).Pos(), "unfixed reflect type check") +} + +// reflectFixAssert rewrites x.(T) to x for any old reflect type T. +func reflectFixAssert(n interface{}) bool { + ptr, ok := n.(*ast.Expr) + if ok { + ta, ok := (*ptr).(*ast.TypeAssertExpr) + if ok && reflectType(ta.Type) != "" { + *ptr = ta.X + return true + } + } + return false +} + +// Tables describing the transformations. + +// Description of old reflect API for partial type checking. +// We pretend the Elem method is on Type and Value instead +// of enumerating all the types it is actually on. +// Also, we pretend that ArrayType etc embeds Type for the +// purposes of describing the API. (In fact they embed commonType, +// which implements Type.) 
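// Reading one entry of the table below, for example (the interpretation is the
// toy type checker's, defined in typecheck.go and not shown in this file):
//	"reflect.ChanValue": &Type{
//		Method: map[string]string{"Recv": "func() (reflect.Value, bool)"},
//		Embed:  []string{"reflect.Value"},
//	}
// records that a ChanValue has a Recv method with that signature and otherwise
// picks up the methods of reflect.Value through Embed.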
+var reflectTypeConfig = &TypeConfig{ + Type: map[string]*Type{ + "reflect.ArrayOrSliceType": &Type{Embed: []string{"reflect.Type"}}, + "reflect.ArrayOrSliceValue": &Type{Embed: []string{"reflect.Value"}}, + "reflect.ArrayType": &Type{Embed: []string{"reflect.Type"}}, + "reflect.ArrayValue": &Type{Embed: []string{"reflect.Value"}}, + "reflect.BoolType": &Type{Embed: []string{"reflect.Type"}}, + "reflect.BoolValue": &Type{Embed: []string{"reflect.Value"}}, + "reflect.ChanType": &Type{Embed: []string{"reflect.Type"}}, + "reflect.ChanValue": &Type{ + Method: map[string]string{ + "Recv": "func() (reflect.Value, bool)", + "TryRecv": "func() (reflect.Value, bool)", + }, + Embed: []string{"reflect.Value"}, + }, + "reflect.ComplexType": &Type{Embed: []string{"reflect.Type"}}, + "reflect.ComplexValue": &Type{Embed: []string{"reflect.Value"}}, + "reflect.FloatType": &Type{Embed: []string{"reflect.Type"}}, + "reflect.FloatValue": &Type{Embed: []string{"reflect.Value"}}, + "reflect.FuncType": &Type{ + Method: map[string]string{ + "In": "func(int) reflect.Type", + "Out": "func(int) reflect.Type", + }, + Embed: []string{"reflect.Type"}, + }, + "reflect.FuncValue": &Type{ + Method: map[string]string{ + "Call": "func([]reflect.Value) []reflect.Value", + }, + }, + "reflect.IntType": &Type{Embed: []string{"reflect.Type"}}, + "reflect.IntValue": &Type{Embed: []string{"reflect.Value"}}, + "reflect.InterfaceType": &Type{Embed: []string{"reflect.Type"}}, + "reflect.InterfaceValue": &Type{Embed: []string{"reflect.Value"}}, + "reflect.MapType": &Type{ + Method: map[string]string{ + "Key": "func() reflect.Type", + }, + Embed: []string{"reflect.Type"}, + }, + "reflect.MapValue": &Type{ + Method: map[string]string{ + "Keys": "func() []reflect.Value", + }, + Embed: []string{"reflect.Value"}, + }, + "reflect.Method": &Type{ + Field: map[string]string{ + "Type": "*reflect.FuncType", + "Func": "*reflect.FuncValue", + }, + }, + "reflect.PtrType": &Type{Embed: []string{"reflect.Type"}}, + "reflect.PtrValue": &Type{Embed: []string{"reflect.Value"}}, + "reflect.SliceType": &Type{Embed: []string{"reflect.Type"}}, + "reflect.SliceValue": &Type{ + Method: map[string]string{ + "Slice": "func(int, int) *reflect.SliceValue", + }, + Embed: []string{"reflect.Value"}, + }, + "reflect.StringType": &Type{Embed: []string{"reflect.Type"}}, + "reflect.StringValue": &Type{Embed: []string{"reflect.Value"}}, + "reflect.StructField": &Type{ + Field: map[string]string{ + "Type": "reflect.Type", + }, + }, + "reflect.StructType": &Type{ + Method: map[string]string{ + "Field": "func() reflect.StructField", + "FieldByIndex": "func() reflect.StructField", + "FieldByName": "func() reflect.StructField,bool", + "FieldByNameFunc": "func() reflect.StructField,bool", + }, + Embed: []string{"reflect.Type"}, + }, + "reflect.StructValue": &Type{ + Method: map[string]string{ + "Field": "func() reflect.Value", + "FieldByIndex": "func() reflect.Value", + "FieldByName": "func() reflect.Value", + "FieldByNameFunc": "func() reflect.Value", + }, + Embed: []string{"reflect.Value"}, + }, + "reflect.Type": &Type{ + Method: map[string]string{ + "Elem": "func() reflect.Type", + "Method": "func() reflect.Method", + }, + }, + "reflect.UintType": &Type{Embed: []string{"reflect.Type"}}, + "reflect.UintValue": &Type{Embed: []string{"reflect.Value"}}, + "reflect.UnsafePointerType": &Type{Embed: []string{"reflect.Type"}}, + "reflect.UnsafePointerValue": &Type{Embed: []string{"reflect.Value"}}, + "reflect.Value": &Type{ + Method: map[string]string{ + "Addr": "func() 
*reflect.PtrValue", + "Elem": "func() reflect.Value", + "Method": "func() *reflect.FuncValue", + "SetValue": "func(reflect.Value)", + }, + }, + }, + Func: map[string]string{ + "reflect.Append": "*reflect.SliceValue", + "reflect.AppendSlice": "*reflect.SliceValue", + "reflect.Indirect": "reflect.Value", + "reflect.MakeSlice": "*reflect.SliceValue", + "reflect.MakeChan": "*reflect.ChanValue", + "reflect.MakeMap": "*reflect.MapValue", + "reflect.MakeZero": "reflect.Value", + "reflect.NewValue": "reflect.Value", + "reflect.PtrTo": "*reflect.PtrType", + "reflect.Typeof": "reflect.Type", + }, +} + +var reflectRewriteMethod = map[string]map[string]string{ + // The type API didn't change much. + "*reflect.ChanType": {"Dir": "ChanDir"}, + "*reflect.FuncType": {"DotDotDot": "IsVariadic"}, + + // The value API has longer names to disambiguate + // methods with different signatures. + "reflect.ArrayOrSliceValue": { // interface, not pointer + "Elem": "Index", + }, + "*reflect.ArrayValue": { + "Elem": "Index", + }, + "*reflect.BoolValue": { + "Get": "Bool", + "Set": "SetBool", + }, + "*reflect.ChanValue": { + "Get": "Pointer", + }, + "*reflect.ComplexValue": { + "Get": "Complex", + "Set": "SetComplex", + "Overflow": "OverflowComplex", + }, + "*reflect.FloatValue": { + "Get": "Float", + "Set": "SetFloat", + "Overflow": "OverflowFloat", + }, + "*reflect.FuncValue": { + "Get": "Pointer", + }, + "*reflect.IntValue": { + "Get": "Int", + "Set": "SetInt", + "Overflow": "OverflowInt", + }, + "*reflect.InterfaceValue": { + "Get": "InterfaceData", + }, + "*reflect.MapValue": { + "Elem": "MapIndex", + "Get": "Pointer", + "Keys": "MapKeys", + "SetElem": "SetMapIndex", + }, + "*reflect.PtrValue": { + "Get": "Pointer", + }, + "*reflect.SliceValue": { + "Elem": "Index", + "Get": "Pointer", + }, + "*reflect.StringValue": { + "Get": "String", + "Set": "SetString", + }, + "*reflect.UintValue": { + "Get": "Uint", + "Set": "SetUint", + "Overflow": "OverflowUint", + }, + "*reflect.UnsafePointerValue": { + "Get": "Pointer", + "Set": "SetPointer", + }, +} + +var reflectKind = map[string][]string{ + "reflect.ArrayOrSliceType": {"Array", "Slice"}, // interface, not pointer + "*reflect.ArrayType": {"Array"}, + "*reflect.BoolType": {"Bool"}, + "*reflect.ChanType": {"Chan"}, + "*reflect.ComplexType": {"Complex64", "Complex128"}, + "*reflect.FloatType": {"Float32", "Float64"}, + "*reflect.FuncType": {"Func"}, + "*reflect.IntType": {"Int", "Int8", "Int16", "Int32", "Int64"}, + "*reflect.InterfaceType": {"Interface"}, + "*reflect.MapType": {"Map"}, + "*reflect.PtrType": {"Ptr"}, + "*reflect.SliceType": {"Slice"}, + "*reflect.StringType": {"String"}, + "*reflect.StructType": {"Struct"}, + "*reflect.UintType": {"Uint", "Uint8", "Uint16", "Uint32", "Uint64", "Uintptr"}, + "*reflect.UnsafePointerType": {"UnsafePointer"}, + + "reflect.ArrayOrSliceValue": {"Array", "Slice"}, // interface, not pointer + "*reflect.ArrayValue": {"Array"}, + "*reflect.BoolValue": {"Bool"}, + "*reflect.ChanValue": {"Chan"}, + "*reflect.ComplexValue": {"Complex64", "Complex128"}, + "*reflect.FloatValue": {"Float32", "Float64"}, + "*reflect.FuncValue": {"Func"}, + "*reflect.IntValue": {"Int", "Int8", "Int16", "Int32", "Int64"}, + "*reflect.InterfaceValue": {"Interface"}, + "*reflect.MapValue": {"Map"}, + "*reflect.PtrValue": {"Ptr"}, + "*reflect.SliceValue": {"Slice"}, + "*reflect.StringValue": {"String"}, + "*reflect.StructValue": {"Struct"}, + "*reflect.UintValue": {"Uint", "Uint8", "Uint16", "Uint32", "Uint64", "Uintptr"}, + "*reflect.UnsafePointerValue": 
{"UnsafePointer"}, +} + +var isReflectValue = map[string]bool{ + "reflect.ArrayOrSliceValue": true, // interface, not pointer + "*reflect.ArrayValue": true, + "*reflect.BoolValue": true, + "*reflect.ChanValue": true, + "*reflect.ComplexValue": true, + "*reflect.FloatValue": true, + "*reflect.FuncValue": true, + "*reflect.IntValue": true, + "*reflect.InterfaceValue": true, + "*reflect.MapValue": true, + "*reflect.PtrValue": true, + "*reflect.SliceValue": true, + "*reflect.StringValue": true, + "*reflect.StructValue": true, + "*reflect.UintValue": true, + "*reflect.UnsafePointerValue": true, + "reflect.Value": true, // interface, not pointer +} diff --git a/src/cmd/gofix/reflect_test.go b/src/cmd/gofix/reflect_test.go new file mode 100644 index 000000000..00edf30e9 --- /dev/null +++ b/src/cmd/gofix/reflect_test.go @@ -0,0 +1,31 @@ +package main + +import ( + "io/ioutil" + "log" + "path/filepath" +) + +func init() { + addTestCases(reflectTests()) +} + +func reflectTests() []testCase { + var tests []testCase + + names, _ := filepath.Glob("testdata/reflect.*.in") + for _, in := range names { + out := in[:len(in)-len(".in")] + ".out" + inb, err := ioutil.ReadFile(in) + if err != nil { + log.Fatal(err) + } + outb, err := ioutil.ReadFile(out) + if err != nil { + log.Fatal(err) + } + tests = append(tests, testCase{Name: in, In: string(inb), Out: string(outb)}) + } + + return tests +} diff --git a/src/cmd/gofix/testdata/reflect.asn1.go.in b/src/cmd/gofix/testdata/reflect.asn1.go.in new file mode 100644 index 000000000..c5314517b --- /dev/null +++ b/src/cmd/gofix/testdata/reflect.asn1.go.in @@ -0,0 +1,815 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// The asn1 package implements parsing of DER-encoded ASN.1 data structures, +// as defined in ITU-T Rec X.690. +// +// See also ``A Layman's Guide to a Subset of ASN.1, BER, and DER,'' +// http://luca.ntop.org/Teaching/Appunti/asn1.html. +package asn1 + +// ASN.1 is a syntax for specifying abstract objects and BER, DER, PER, XER etc +// are different encoding formats for those objects. Here, we'll be dealing +// with DER, the Distinguished Encoding Rules. DER is used in X.509 because +// it's fast to parse and, unlike BER, has a unique encoding for every object. +// When calculating hashes over objects, it's important that the resulting +// bytes be the same at both ends and DER removes this margin of error. +// +// ASN.1 is very complex and this package doesn't attempt to implement +// everything by any means. + +import ( + "fmt" + "os" + "reflect" + "time" +) + +// A StructuralError suggests that the ASN.1 data is valid, but the Go type +// which is receiving it doesn't match. +type StructuralError struct { + Msg string +} + +func (e StructuralError) String() string { return "ASN.1 structure error: " + e.Msg } + +// A SyntaxError suggests that the ASN.1 data is invalid. +type SyntaxError struct { + Msg string +} + +func (e SyntaxError) String() string { return "ASN.1 syntax error: " + e.Msg } + +// We start by dealing with each of the primitive types in turn. + +// BOOLEAN + +func parseBool(bytes []byte) (ret bool, err os.Error) { + if len(bytes) != 1 { + err = SyntaxError{"invalid boolean"} + return + } + + return bytes[0] != 0, nil +} + +// INTEGER + +// parseInt64 treats the given bytes as a big-endian, signed integer and +// returns the result. 
+func parseInt64(bytes []byte) (ret int64, err os.Error) { + if len(bytes) > 8 { + // We'll overflow an int64 in this case. + err = StructuralError{"integer too large"} + return + } + for bytesRead := 0; bytesRead < len(bytes); bytesRead++ { + ret <<= 8 + ret |= int64(bytes[bytesRead]) + } + + // Shift up and down in order to sign extend the result. + ret <<= 64 - uint8(len(bytes))*8 + ret >>= 64 - uint8(len(bytes))*8 + return +} + +// parseInt treats the given bytes as a big-endian, signed integer and returns +// the result. +func parseInt(bytes []byte) (int, os.Error) { + ret64, err := parseInt64(bytes) + if err != nil { + return 0, err + } + if ret64 != int64(int(ret64)) { + return 0, StructuralError{"integer too large"} + } + return int(ret64), nil +} + +// BIT STRING + +// BitString is the structure to use when you want an ASN.1 BIT STRING type. A +// bit string is padded up to the nearest byte in memory and the number of +// valid bits is recorded. Padding bits will be zero. +type BitString struct { + Bytes []byte // bits packed into bytes. + BitLength int // length in bits. +} + +// At returns the bit at the given index. If the index is out of range it +// returns false. +func (b BitString) At(i int) int { + if i < 0 || i >= b.BitLength { + return 0 + } + x := i / 8 + y := 7 - uint(i%8) + return int(b.Bytes[x]>>y) & 1 +} + +// RightAlign returns a slice where the padding bits are at the beginning. The +// slice may share memory with the BitString. +func (b BitString) RightAlign() []byte { + shift := uint(8 - (b.BitLength % 8)) + if shift == 8 || len(b.Bytes) == 0 { + return b.Bytes + } + + a := make([]byte, len(b.Bytes)) + a[0] = b.Bytes[0] >> shift + for i := 1; i < len(b.Bytes); i++ { + a[i] = b.Bytes[i-1] << (8 - shift) + a[i] |= b.Bytes[i] >> shift + } + + return a +} + +// parseBitString parses an ASN.1 bit string from the given byte array and returns it. +func parseBitString(bytes []byte) (ret BitString, err os.Error) { + if len(bytes) == 0 { + err = SyntaxError{"zero length BIT STRING"} + return + } + paddingBits := int(bytes[0]) + if paddingBits > 7 || + len(bytes) == 1 && paddingBits > 0 || + bytes[len(bytes)-1]&((1<<bytes[0])-1) != 0 { + err = SyntaxError{"invalid padding bits in BIT STRING"} + return + } + ret.BitLength = (len(bytes)-1)*8 - paddingBits + ret.Bytes = bytes[1:] + return +} + +// OBJECT IDENTIFIER + +// An ObjectIdentifier represents an ASN.1 OBJECT IDENTIFIER. +type ObjectIdentifier []int + +// Equal returns true iff oi and other represent the same identifier. +func (oi ObjectIdentifier) Equal(other ObjectIdentifier) bool { + if len(oi) != len(other) { + return false + } + for i := 0; i < len(oi); i++ { + if oi[i] != other[i] { + return false + } + } + + return true +} + +// parseObjectIdentifier parses an OBJECT IDENTIFER from the given bytes and +// returns it. An object identifer is a sequence of variable length integers +// that are assigned in a hierarachy. +func parseObjectIdentifier(bytes []byte) (s []int, err os.Error) { + if len(bytes) == 0 { + err = SyntaxError{"zero length OBJECT IDENTIFIER"} + return + } + + // In the worst case, we get two elements from the first byte (which is + // encoded differently) and then every varint is a single byte long. 
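// For example, a first byte of 0x2A (42) gives s[0] = 42/40 = 1 and
// s[1] = 42%40 = 2, so the identifier starts with 1.2.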
+ s = make([]int, len(bytes)+1) + + // The first byte is 40*value1 + value2: + s[0] = int(bytes[0]) / 40 + s[1] = int(bytes[0]) % 40 + i := 2 + for offset := 1; offset < len(bytes); i++ { + var v int + v, offset, err = parseBase128Int(bytes, offset) + if err != nil { + return + } + s[i] = v + } + s = s[0:i] + return +} + +// ENUMERATED + +// An Enumerated is represented as a plain int. +type Enumerated int + + +// FLAG + +// A Flag accepts any data and is set to true if present. +type Flag bool + +// parseBase128Int parses a base-128 encoded int from the given offset in the +// given byte array. It returns the value and the new offset. +func parseBase128Int(bytes []byte, initOffset int) (ret, offset int, err os.Error) { + offset = initOffset + for shifted := 0; offset < len(bytes); shifted++ { + if shifted > 4 { + err = StructuralError{"base 128 integer too large"} + return + } + ret <<= 7 + b := bytes[offset] + ret |= int(b & 0x7f) + offset++ + if b&0x80 == 0 { + return + } + } + err = SyntaxError{"truncated base 128 integer"} + return +} + +// UTCTime + +func parseUTCTime(bytes []byte) (ret *time.Time, err os.Error) { + s := string(bytes) + ret, err = time.Parse("0601021504Z0700", s) + if err == nil { + return + } + ret, err = time.Parse("060102150405Z0700", s) + return +} + +// parseGeneralizedTime parses the GeneralizedTime from the given byte array +// and returns the resulting time. +func parseGeneralizedTime(bytes []byte) (ret *time.Time, err os.Error) { + return time.Parse("20060102150405Z0700", string(bytes)) +} + +// PrintableString + +// parsePrintableString parses a ASN.1 PrintableString from the given byte +// array and returns it. +func parsePrintableString(bytes []byte) (ret string, err os.Error) { + for _, b := range bytes { + if !isPrintable(b) { + err = SyntaxError{"PrintableString contains invalid character"} + return + } + } + ret = string(bytes) + return +} + +// isPrintable returns true iff the given b is in the ASN.1 PrintableString set. +func isPrintable(b byte) bool { + return 'a' <= b && b <= 'z' || + 'A' <= b && b <= 'Z' || + '0' <= b && b <= '9' || + '\'' <= b && b <= ')' || + '+' <= b && b <= '/' || + b == ' ' || + b == ':' || + b == '=' || + b == '?' || + // This is techincally not allowed in a PrintableString. + // However, x509 certificates with wildcard strings don't + // always use the correct string type so we permit it. + b == '*' +} + +// IA5String + +// parseIA5String parses a ASN.1 IA5String (ASCII string) from the given +// byte array and returns it. +func parseIA5String(bytes []byte) (ret string, err os.Error) { + for _, b := range bytes { + if b >= 0x80 { + err = SyntaxError{"IA5String contains invalid character"} + return + } + } + ret = string(bytes) + return +} + +// T61String + +// parseT61String parses a ASN.1 T61String (8-bit clean string) from the given +// byte array and returns it. +func parseT61String(bytes []byte) (ret string, err os.Error) { + return string(bytes), nil +} + +// A RawValue represents an undecoded ASN.1 object. +type RawValue struct { + Class, Tag int + IsCompound bool + Bytes []byte + FullBytes []byte // includes the tag and length +} + +// RawContent is used to signal that the undecoded, DER data needs to be +// preserved for a struct. To use it, the first field of the struct must have +// this type. It's an error for any of the other fields to have this type. +type RawContent []byte + +// Tagging + +// parseTagAndLength parses an ASN.1 tag and length pair from the given offset +// into a byte array. 
It returns the parsed data and the new offset. SET and +// SET OF (tag 17) are mapped to SEQUENCE and SEQUENCE OF (tag 16) since we +// don't distinguish between ordered and unordered objects in this code. +func parseTagAndLength(bytes []byte, initOffset int) (ret tagAndLength, offset int, err os.Error) { + offset = initOffset + b := bytes[offset] + offset++ + ret.class = int(b >> 6) + ret.isCompound = b&0x20 == 0x20 + ret.tag = int(b & 0x1f) + + // If the bottom five bits are set, then the tag number is actually base 128 + // encoded afterwards + if ret.tag == 0x1f { + ret.tag, offset, err = parseBase128Int(bytes, offset) + if err != nil { + return + } + } + if offset >= len(bytes) { + err = SyntaxError{"truncated tag or length"} + return + } + b = bytes[offset] + offset++ + if b&0x80 == 0 { + // The length is encoded in the bottom 7 bits. + ret.length = int(b & 0x7f) + } else { + // Bottom 7 bits give the number of length bytes to follow. + numBytes := int(b & 0x7f) + // We risk overflowing a signed 32-bit number if we accept more than 3 bytes. + if numBytes > 3 { + err = StructuralError{"length too large"} + return + } + if numBytes == 0 { + err = SyntaxError{"indefinite length found (not DER)"} + return + } + ret.length = 0 + for i := 0; i < numBytes; i++ { + if offset >= len(bytes) { + err = SyntaxError{"truncated tag or length"} + return + } + b = bytes[offset] + offset++ + ret.length <<= 8 + ret.length |= int(b) + } + } + + return +} + +// parseSequenceOf is used for SEQUENCE OF and SET OF values. It tries to parse +// a number of ASN.1 values from the given byte array and returns them as a +// slice of Go values of the given type. +func parseSequenceOf(bytes []byte, sliceType *reflect.SliceType, elemType reflect.Type) (ret *reflect.SliceValue, err os.Error) { + expectedTag, compoundType, ok := getUniversalType(elemType) + if !ok { + err = StructuralError{"unknown Go type for slice"} + return + } + + // First we iterate over the input and count the number of elements, + // checking that the types are correct in each case. + numElements := 0 + for offset := 0; offset < len(bytes); { + var t tagAndLength + t, offset, err = parseTagAndLength(bytes, offset) + if err != nil { + return + } + // We pretend that GENERAL STRINGs are PRINTABLE STRINGs so + // that a sequence of them can be parsed into a []string. + if t.tag == tagGeneralString { + t.tag = tagPrintableString + } + if t.class != classUniversal || t.isCompound != compoundType || t.tag != expectedTag { + err = StructuralError{"sequence tag mismatch"} + return + } + if invalidLength(offset, t.length, len(bytes)) { + err = SyntaxError{"truncated sequence"} + return + } + offset += t.length + numElements++ + } + ret = reflect.MakeSlice(sliceType, numElements, numElements) + params := fieldParameters{} + offset := 0 + for i := 0; i < numElements; i++ { + offset, err = parseField(ret.Elem(i), bytes, offset, params) + if err != nil { + return + } + } + return +} + +var ( + bitStringType = reflect.Typeof(BitString{}) + objectIdentifierType = reflect.Typeof(ObjectIdentifier{}) + enumeratedType = reflect.Typeof(Enumerated(0)) + flagType = reflect.Typeof(Flag(false)) + timeType = reflect.Typeof(&time.Time{}) + rawValueType = reflect.Typeof(RawValue{}) + rawContentsType = reflect.Typeof(RawContent(nil)) +) + +// invalidLength returns true iff offset + length > sliceLength, or if the +// addition would overflow. 
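// For example, if length is large enough that offset+length wraps around,
// the first comparison (offset+length < offset) reports the length as invalid.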
+func invalidLength(offset, length, sliceLength int) bool { + return offset+length < offset || offset+length > sliceLength +} + +// parseField is the main parsing function. Given a byte array and an offset +// into the array, it will try to parse a suitable ASN.1 value out and store it +// in the given Value. +func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParameters) (offset int, err os.Error) { + offset = initOffset + fieldType := v.Type() + + // If we have run out of data, it may be that there are optional elements at the end. + if offset == len(bytes) { + if !setDefaultValue(v, params) { + err = SyntaxError{"sequence truncated"} + } + return + } + + // Deal with raw values. + if fieldType == rawValueType { + var t tagAndLength + t, offset, err = parseTagAndLength(bytes, offset) + if err != nil { + return + } + if invalidLength(offset, t.length, len(bytes)) { + err = SyntaxError{"data truncated"} + return + } + result := RawValue{t.class, t.tag, t.isCompound, bytes[offset : offset+t.length], bytes[initOffset : offset+t.length]} + offset += t.length + v.(*reflect.StructValue).Set(reflect.NewValue(result).(*reflect.StructValue)) + return + } + + // Deal with the ANY type. + if ifaceType, ok := fieldType.(*reflect.InterfaceType); ok && ifaceType.NumMethod() == 0 { + ifaceValue := v.(*reflect.InterfaceValue) + var t tagAndLength + t, offset, err = parseTagAndLength(bytes, offset) + if err != nil { + return + } + if invalidLength(offset, t.length, len(bytes)) { + err = SyntaxError{"data truncated"} + return + } + var result interface{} + if !t.isCompound && t.class == classUniversal { + innerBytes := bytes[offset : offset+t.length] + switch t.tag { + case tagPrintableString: + result, err = parsePrintableString(innerBytes) + case tagIA5String: + result, err = parseIA5String(innerBytes) + case tagT61String: + result, err = parseT61String(innerBytes) + case tagInteger: + result, err = parseInt64(innerBytes) + case tagBitString: + result, err = parseBitString(innerBytes) + case tagOID: + result, err = parseObjectIdentifier(innerBytes) + case tagUTCTime: + result, err = parseUTCTime(innerBytes) + case tagOctetString: + result = innerBytes + default: + // If we don't know how to handle the type, we just leave Value as nil. + } + } + offset += t.length + if err != nil { + return + } + if result != nil { + ifaceValue.Set(reflect.NewValue(result)) + } + return + } + universalTag, compoundType, ok1 := getUniversalType(fieldType) + if !ok1 { + err = StructuralError{fmt.Sprintf("unknown Go type: %v", fieldType)} + return + } + + t, offset, err := parseTagAndLength(bytes, offset) + if err != nil { + return + } + if params.explicit { + expectedClass := classContextSpecific + if params.application { + expectedClass = classApplication + } + if t.class == expectedClass && t.tag == *params.tag && (t.length == 0 || t.isCompound) { + if t.length > 0 { + t, offset, err = parseTagAndLength(bytes, offset) + if err != nil { + return + } + } else { + if fieldType != flagType { + err = StructuralError{"Zero length explicit tag was not an asn1.Flag"} + return + } + + flagValue := v.(*reflect.BoolValue) + flagValue.Set(true) + return + } + } else { + // The tags didn't match, it might be an optional element. + ok := setDefaultValue(v, params) + if ok { + offset = initOffset + } else { + err = StructuralError{"explicitly tagged member didn't match"} + } + return + } + } + + // Special case for strings: PrintableString and IA5String both map to + // the Go type string. 
getUniversalType returns the tag for + // PrintableString when it sees a string so, if we see an IA5String on + // the wire, we change the universal type to match. + if universalTag == tagPrintableString && t.tag == tagIA5String { + universalTag = tagIA5String + } + // Likewise for GeneralString + if universalTag == tagPrintableString && t.tag == tagGeneralString { + universalTag = tagGeneralString + } + + // Special case for time: UTCTime and GeneralizedTime both map to the + // Go type time.Time. + if universalTag == tagUTCTime && t.tag == tagGeneralizedTime { + universalTag = tagGeneralizedTime + } + + expectedClass := classUniversal + expectedTag := universalTag + + if !params.explicit && params.tag != nil { + expectedClass = classContextSpecific + expectedTag = *params.tag + } + + if !params.explicit && params.application && params.tag != nil { + expectedClass = classApplication + expectedTag = *params.tag + } + + // We have unwrapped any explicit tagging at this point. + if t.class != expectedClass || t.tag != expectedTag || t.isCompound != compoundType { + // Tags don't match. Again, it could be an optional element. + ok := setDefaultValue(v, params) + if ok { + offset = initOffset + } else { + err = StructuralError{fmt.Sprintf("tags don't match (%d vs %+v) %+v %s @%d", expectedTag, t, params, fieldType.Name(), offset)} + } + return + } + if invalidLength(offset, t.length, len(bytes)) { + err = SyntaxError{"data truncated"} + return + } + innerBytes := bytes[offset : offset+t.length] + offset += t.length + + // We deal with the structures defined in this package first. + switch fieldType { + case objectIdentifierType: + newSlice, err1 := parseObjectIdentifier(innerBytes) + sliceValue := v.(*reflect.SliceValue) + sliceValue.Set(reflect.MakeSlice(sliceValue.Type().(*reflect.SliceType), len(newSlice), len(newSlice))) + if err1 == nil { + reflect.Copy(sliceValue, reflect.NewValue(newSlice).(reflect.ArrayOrSliceValue)) + } + err = err1 + return + case bitStringType: + structValue := v.(*reflect.StructValue) + bs, err1 := parseBitString(innerBytes) + if err1 == nil { + structValue.Set(reflect.NewValue(bs).(*reflect.StructValue)) + } + err = err1 + return + case timeType: + ptrValue := v.(*reflect.PtrValue) + var time *time.Time + var err1 os.Error + if universalTag == tagUTCTime { + time, err1 = parseUTCTime(innerBytes) + } else { + time, err1 = parseGeneralizedTime(innerBytes) + } + if err1 == nil { + ptrValue.Set(reflect.NewValue(time).(*reflect.PtrValue)) + } + err = err1 + return + case enumeratedType: + parsedInt, err1 := parseInt(innerBytes) + enumValue := v.(*reflect.IntValue) + if err1 == nil { + enumValue.Set(int64(parsedInt)) + } + err = err1 + return + case flagType: + flagValue := v.(*reflect.BoolValue) + flagValue.Set(true) + return + } + switch val := v.(type) { + case *reflect.BoolValue: + parsedBool, err1 := parseBool(innerBytes) + if err1 == nil { + val.Set(parsedBool) + } + err = err1 + return + case *reflect.IntValue: + switch val.Type().Kind() { + case reflect.Int: + parsedInt, err1 := parseInt(innerBytes) + if err1 == nil { + val.Set(int64(parsedInt)) + } + err = err1 + return + case reflect.Int64: + parsedInt, err1 := parseInt64(innerBytes) + if err1 == nil { + val.Set(parsedInt) + } + err = err1 + return + } + case *reflect.StructValue: + structType := fieldType.(*reflect.StructType) + + if structType.NumField() > 0 && + structType.Field(0).Type == rawContentsType { + bytes := bytes[initOffset:offset] + val.Field(0).SetValue(reflect.NewValue(RawContent(bytes))) + } 
+ + innerOffset := 0 + for i := 0; i < structType.NumField(); i++ { + field := structType.Field(i) + if i == 0 && field.Type == rawContentsType { + continue + } + innerOffset, err = parseField(val.Field(i), innerBytes, innerOffset, parseFieldParameters(field.Tag)) + if err != nil { + return + } + } + // We allow extra bytes at the end of the SEQUENCE because + // adding elements to the end has been used in X.509 as the + // version numbers have increased. + return + case *reflect.SliceValue: + sliceType := fieldType.(*reflect.SliceType) + if sliceType.Elem().Kind() == reflect.Uint8 { + val.Set(reflect.MakeSlice(sliceType, len(innerBytes), len(innerBytes))) + reflect.Copy(val, reflect.NewValue(innerBytes).(reflect.ArrayOrSliceValue)) + return + } + newSlice, err1 := parseSequenceOf(innerBytes, sliceType, sliceType.Elem()) + if err1 == nil { + val.Set(newSlice) + } + err = err1 + return + case *reflect.StringValue: + var v string + switch universalTag { + case tagPrintableString: + v, err = parsePrintableString(innerBytes) + case tagIA5String: + v, err = parseIA5String(innerBytes) + case tagT61String: + v, err = parseT61String(innerBytes) + case tagGeneralString: + // GeneralString is specified in ISO-2022/ECMA-35, + // A brief review suggests that it includes structures + // that allow the encoding to change midstring and + // such. We give up and pass it as an 8-bit string. + v, err = parseT61String(innerBytes) + default: + err = SyntaxError{fmt.Sprintf("internal error: unknown string type %d", universalTag)} + } + if err == nil { + val.Set(v) + } + return + } + err = StructuralError{"unknown Go type"} + return +} + +// setDefaultValue is used to install a default value, from a tag string, into +// a Value. It is successful is the field was optional, even if a default value +// wasn't provided or it failed to install it into the Value. +func setDefaultValue(v reflect.Value, params fieldParameters) (ok bool) { + if !params.optional { + return + } + ok = true + if params.defaultValue == nil { + return + } + switch val := v.(type) { + case *reflect.IntValue: + val.Set(*params.defaultValue) + } + return +} + +// Unmarshal parses the DER-encoded ASN.1 data structure b +// and uses the reflect package to fill in an arbitrary value pointed at by val. +// Because Unmarshal uses the reflect package, the structs +// being written to must use upper case field names. +// +// An ASN.1 INTEGER can be written to an int or int64. +// If the encoded value does not fit in the Go type, +// Unmarshal returns a parse error. +// +// An ASN.1 BIT STRING can be written to a BitString. +// +// An ASN.1 OCTET STRING can be written to a []byte. +// +// An ASN.1 OBJECT IDENTIFIER can be written to an +// ObjectIdentifier. +// +// An ASN.1 ENUMERATED can be written to an Enumerated. +// +// An ASN.1 UTCTIME or GENERALIZEDTIME can be written to a *time.Time. +// +// An ASN.1 PrintableString or IA5String can be written to a string. +// +// Any of the above ASN.1 values can be written to an interface{}. +// The value stored in the interface has the corresponding Go type. +// For integers, that type is int64. +// +// An ASN.1 SEQUENCE OF x or SET OF x can be written +// to a slice if an x can be written to the slice's element type. +// +// An ASN.1 SEQUENCE or SET can be written to a struct +// if each of the elements in the sequence can be +// written to the corresponding element in the struct. 
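// For example (an illustrative type, not taken from the package's tests):
//
//	type Point struct {
//		X, Y int
//	}
//	var p Point
//	rest, err := Unmarshal(der, &p)
//
// fills p from a DER SEQUENCE of two INTEGERs in der and returns any
// remaining bytes in rest.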
+// +// The following tags on struct fields have special meaning to Unmarshal: +// +// optional marks the field as ASN.1 OPTIONAL +// [explicit] tag:x specifies the ASN.1 tag number; implies ASN.1 CONTEXT SPECIFIC +// default:x sets the default value for optional integer fields +// +// If the type of the first field of a structure is RawContent then the raw +// ASN1 contents of the struct will be stored in it. +// +// Other ASN.1 types are not supported; if it encounters them, +// Unmarshal returns a parse error. +func Unmarshal(b []byte, val interface{}) (rest []byte, err os.Error) { + return UnmarshalWithParams(b, val, "") +} + +// UnmarshalWithParams allows field parameters to be specified for the +// top-level element. The form of the params is the same as the field tags. +func UnmarshalWithParams(b []byte, val interface{}, params string) (rest []byte, err os.Error) { + v := reflect.NewValue(val).(*reflect.PtrValue).Elem() + offset, err := parseField(v, b, 0, parseFieldParameters(params)) + if err != nil { + return nil, err + } + return b[offset:], nil +} diff --git a/src/cmd/gofix/testdata/reflect.asn1.go.out b/src/cmd/gofix/testdata/reflect.asn1.go.out new file mode 100644 index 000000000..902635939 --- /dev/null +++ b/src/cmd/gofix/testdata/reflect.asn1.go.out @@ -0,0 +1,815 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// The asn1 package implements parsing of DER-encoded ASN.1 data structures, +// as defined in ITU-T Rec X.690. +// +// See also ``A Layman's Guide to a Subset of ASN.1, BER, and DER,'' +// http://luca.ntop.org/Teaching/Appunti/asn1.html. +package asn1 + +// ASN.1 is a syntax for specifying abstract objects and BER, DER, PER, XER etc +// are different encoding formats for those objects. Here, we'll be dealing +// with DER, the Distinguished Encoding Rules. DER is used in X.509 because +// it's fast to parse and, unlike BER, has a unique encoding for every object. +// When calculating hashes over objects, it's important that the resulting +// bytes be the same at both ends and DER removes this margin of error. +// +// ASN.1 is very complex and this package doesn't attempt to implement +// everything by any means. + +import ( + "fmt" + "os" + "reflect" + "time" +) + +// A StructuralError suggests that the ASN.1 data is valid, but the Go type +// which is receiving it doesn't match. +type StructuralError struct { + Msg string +} + +func (e StructuralError) String() string { return "ASN.1 structure error: " + e.Msg } + +// A SyntaxError suggests that the ASN.1 data is invalid. +type SyntaxError struct { + Msg string +} + +func (e SyntaxError) String() string { return "ASN.1 syntax error: " + e.Msg } + +// We start by dealing with each of the primitive types in turn. + +// BOOLEAN + +func parseBool(bytes []byte) (ret bool, err os.Error) { + if len(bytes) != 1 { + err = SyntaxError{"invalid boolean"} + return + } + + return bytes[0] != 0, nil +} + +// INTEGER + +// parseInt64 treats the given bytes as a big-endian, signed integer and +// returns the result. +func parseInt64(bytes []byte) (ret int64, err os.Error) { + if len(bytes) > 8 { + // We'll overflow an int64 in this case. + err = StructuralError{"integer too large"} + return + } + for bytesRead := 0; bytesRead < len(bytes); bytesRead++ { + ret <<= 8 + ret |= int64(bytes[bytesRead]) + } + + // Shift up and down in order to sign extend the result. 
+ ret <<= 64 - uint8(len(bytes))*8 + ret >>= 64 - uint8(len(bytes))*8 + return +} + +// parseInt treats the given bytes as a big-endian, signed integer and returns +// the result. +func parseInt(bytes []byte) (int, os.Error) { + ret64, err := parseInt64(bytes) + if err != nil { + return 0, err + } + if ret64 != int64(int(ret64)) { + return 0, StructuralError{"integer too large"} + } + return int(ret64), nil +} + +// BIT STRING + +// BitString is the structure to use when you want an ASN.1 BIT STRING type. A +// bit string is padded up to the nearest byte in memory and the number of +// valid bits is recorded. Padding bits will be zero. +type BitString struct { + Bytes []byte // bits packed into bytes. + BitLength int // length in bits. +} + +// At returns the bit at the given index. If the index is out of range it +// returns false. +func (b BitString) At(i int) int { + if i < 0 || i >= b.BitLength { + return 0 + } + x := i / 8 + y := 7 - uint(i%8) + return int(b.Bytes[x]>>y) & 1 +} + +// RightAlign returns a slice where the padding bits are at the beginning. The +// slice may share memory with the BitString. +func (b BitString) RightAlign() []byte { + shift := uint(8 - (b.BitLength % 8)) + if shift == 8 || len(b.Bytes) == 0 { + return b.Bytes + } + + a := make([]byte, len(b.Bytes)) + a[0] = b.Bytes[0] >> shift + for i := 1; i < len(b.Bytes); i++ { + a[i] = b.Bytes[i-1] << (8 - shift) + a[i] |= b.Bytes[i] >> shift + } + + return a +} + +// parseBitString parses an ASN.1 bit string from the given byte array and returns it. +func parseBitString(bytes []byte) (ret BitString, err os.Error) { + if len(bytes) == 0 { + err = SyntaxError{"zero length BIT STRING"} + return + } + paddingBits := int(bytes[0]) + if paddingBits > 7 || + len(bytes) == 1 && paddingBits > 0 || + bytes[len(bytes)-1]&((1<<bytes[0])-1) != 0 { + err = SyntaxError{"invalid padding bits in BIT STRING"} + return + } + ret.BitLength = (len(bytes)-1)*8 - paddingBits + ret.Bytes = bytes[1:] + return +} + +// OBJECT IDENTIFIER + +// An ObjectIdentifier represents an ASN.1 OBJECT IDENTIFIER. +type ObjectIdentifier []int + +// Equal returns true iff oi and other represent the same identifier. +func (oi ObjectIdentifier) Equal(other ObjectIdentifier) bool { + if len(oi) != len(other) { + return false + } + for i := 0; i < len(oi); i++ { + if oi[i] != other[i] { + return false + } + } + + return true +} + +// parseObjectIdentifier parses an OBJECT IDENTIFER from the given bytes and +// returns it. An object identifer is a sequence of variable length integers +// that are assigned in a hierarachy. +func parseObjectIdentifier(bytes []byte) (s []int, err os.Error) { + if len(bytes) == 0 { + err = SyntaxError{"zero length OBJECT IDENTIFIER"} + return + } + + // In the worst case, we get two elements from the first byte (which is + // encoded differently) and then every varint is a single byte long. + s = make([]int, len(bytes)+1) + + // The first byte is 40*value1 + value2: + s[0] = int(bytes[0]) / 40 + s[1] = int(bytes[0]) % 40 + i := 2 + for offset := 1; offset < len(bytes); i++ { + var v int + v, offset, err = parseBase128Int(bytes, offset) + if err != nil { + return + } + s[i] = v + } + s = s[0:i] + return +} + +// ENUMERATED + +// An Enumerated is represented as a plain int. +type Enumerated int + + +// FLAG + +// A Flag accepts any data and is set to true if present. +type Flag bool + +// parseBase128Int parses a base-128 encoded int from the given offset in the +// given byte array. It returns the value and the new offset. 
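The base-128 encoding handled by the function below packs seven value bits per byte, most significant group first, with the high bit set on every byte except the last. A short worked trace, on illustrative input bytes:

	// Decoding the bytes 0x87 0x68 starting at the given offset:
	//   0x87: low seven bits 0x07, high bit set, keep going: ret = 0x07
	//   0x68: low seven bits 0x68, high bit clear, stop:     ret = 0x07<<7 | 0x68 = 896 + 104 = 1000
	// parseBase128Int returns ret = 1000 with offset advanced past the two bytes.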
+func parseBase128Int(bytes []byte, initOffset int) (ret, offset int, err os.Error) { + offset = initOffset + for shifted := 0; offset < len(bytes); shifted++ { + if shifted > 4 { + err = StructuralError{"base 128 integer too large"} + return + } + ret <<= 7 + b := bytes[offset] + ret |= int(b & 0x7f) + offset++ + if b&0x80 == 0 { + return + } + } + err = SyntaxError{"truncated base 128 integer"} + return +} + +// UTCTime + +func parseUTCTime(bytes []byte) (ret *time.Time, err os.Error) { + s := string(bytes) + ret, err = time.Parse("0601021504Z0700", s) + if err == nil { + return + } + ret, err = time.Parse("060102150405Z0700", s) + return +} + +// parseGeneralizedTime parses the GeneralizedTime from the given byte array +// and returns the resulting time. +func parseGeneralizedTime(bytes []byte) (ret *time.Time, err os.Error) { + return time.Parse("20060102150405Z0700", string(bytes)) +} + +// PrintableString + +// parsePrintableString parses a ASN.1 PrintableString from the given byte +// array and returns it. +func parsePrintableString(bytes []byte) (ret string, err os.Error) { + for _, b := range bytes { + if !isPrintable(b) { + err = SyntaxError{"PrintableString contains invalid character"} + return + } + } + ret = string(bytes) + return +} + +// isPrintable returns true iff the given b is in the ASN.1 PrintableString set. +func isPrintable(b byte) bool { + return 'a' <= b && b <= 'z' || + 'A' <= b && b <= 'Z' || + '0' <= b && b <= '9' || + '\'' <= b && b <= ')' || + '+' <= b && b <= '/' || + b == ' ' || + b == ':' || + b == '=' || + b == '?' || + // This is techincally not allowed in a PrintableString. + // However, x509 certificates with wildcard strings don't + // always use the correct string type so we permit it. + b == '*' +} + +// IA5String + +// parseIA5String parses a ASN.1 IA5String (ASCII string) from the given +// byte array and returns it. +func parseIA5String(bytes []byte) (ret string, err os.Error) { + for _, b := range bytes { + if b >= 0x80 { + err = SyntaxError{"IA5String contains invalid character"} + return + } + } + ret = string(bytes) + return +} + +// T61String + +// parseT61String parses a ASN.1 T61String (8-bit clean string) from the given +// byte array and returns it. +func parseT61String(bytes []byte) (ret string, err os.Error) { + return string(bytes), nil +} + +// A RawValue represents an undecoded ASN.1 object. +type RawValue struct { + Class, Tag int + IsCompound bool + Bytes []byte + FullBytes []byte // includes the tag and length +} + +// RawContent is used to signal that the undecoded, DER data needs to be +// preserved for a struct. To use it, the first field of the struct must have +// this type. It's an error for any of the other fields to have this type. +type RawContent []byte + +// Tagging + +// parseTagAndLength parses an ASN.1 tag and length pair from the given offset +// into a byte array. It returns the parsed data and the new offset. SET and +// SET OF (tag 17) are mapped to SEQUENCE and SEQUENCE OF (tag 16) since we +// don't distinguish between ordered and unordered objects in this code. 
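A worked header for the function below, using an illustrative DER SEQUENCE prefix:

	// Parsing the four header bytes 0x30 0x82 0x01 0x05:
	//   0x30: class = 0x30>>6 = 0 (universal), bit 0x20 set so isCompound = true, tag = 0x30&0x1f = 16 (SEQUENCE)
	//   0x82: high bit set, so long-form length with numBytes = 2
	//   0x01 0x05: length = 0x01<<8 | 0x05 = 261
	// parseTagAndLength returns that tagAndLength with offset advanced by 4; the next 261 bytes are the contents.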
+func parseTagAndLength(bytes []byte, initOffset int) (ret tagAndLength, offset int, err os.Error) { + offset = initOffset + b := bytes[offset] + offset++ + ret.class = int(b >> 6) + ret.isCompound = b&0x20 == 0x20 + ret.tag = int(b & 0x1f) + + // If the bottom five bits are set, then the tag number is actually base 128 + // encoded afterwards + if ret.tag == 0x1f { + ret.tag, offset, err = parseBase128Int(bytes, offset) + if err != nil { + return + } + } + if offset >= len(bytes) { + err = SyntaxError{"truncated tag or length"} + return + } + b = bytes[offset] + offset++ + if b&0x80 == 0 { + // The length is encoded in the bottom 7 bits. + ret.length = int(b & 0x7f) + } else { + // Bottom 7 bits give the number of length bytes to follow. + numBytes := int(b & 0x7f) + // We risk overflowing a signed 32-bit number if we accept more than 3 bytes. + if numBytes > 3 { + err = StructuralError{"length too large"} + return + } + if numBytes == 0 { + err = SyntaxError{"indefinite length found (not DER)"} + return + } + ret.length = 0 + for i := 0; i < numBytes; i++ { + if offset >= len(bytes) { + err = SyntaxError{"truncated tag or length"} + return + } + b = bytes[offset] + offset++ + ret.length <<= 8 + ret.length |= int(b) + } + } + + return +} + +// parseSequenceOf is used for SEQUENCE OF and SET OF values. It tries to parse +// a number of ASN.1 values from the given byte array and returns them as a +// slice of Go values of the given type. +func parseSequenceOf(bytes []byte, sliceType reflect.Type, elemType reflect.Type) (ret reflect.Value, err os.Error) { + expectedTag, compoundType, ok := getUniversalType(elemType) + if !ok { + err = StructuralError{"unknown Go type for slice"} + return + } + + // First we iterate over the input and count the number of elements, + // checking that the types are correct in each case. + numElements := 0 + for offset := 0; offset < len(bytes); { + var t tagAndLength + t, offset, err = parseTagAndLength(bytes, offset) + if err != nil { + return + } + // We pretend that GENERAL STRINGs are PRINTABLE STRINGs so + // that a sequence of them can be parsed into a []string. + if t.tag == tagGeneralString { + t.tag = tagPrintableString + } + if t.class != classUniversal || t.isCompound != compoundType || t.tag != expectedTag { + err = StructuralError{"sequence tag mismatch"} + return + } + if invalidLength(offset, t.length, len(bytes)) { + err = SyntaxError{"truncated sequence"} + return + } + offset += t.length + numElements++ + } + ret = reflect.MakeSlice(sliceType, numElements, numElements) + params := fieldParameters{} + offset := 0 + for i := 0; i < numElements; i++ { + offset, err = parseField(ret.Index(i), bytes, offset, params) + if err != nil { + return + } + } + return +} + +var ( + bitStringType = reflect.Typeof(BitString{}) + objectIdentifierType = reflect.Typeof(ObjectIdentifier{}) + enumeratedType = reflect.Typeof(Enumerated(0)) + flagType = reflect.Typeof(Flag(false)) + timeType = reflect.Typeof(&time.Time{}) + rawValueType = reflect.Typeof(RawValue{}) + rawContentsType = reflect.Typeof(RawContent(nil)) +) + +// invalidLength returns true iff offset + length > sliceLength, or if the +// addition would overflow. +func invalidLength(offset, length, sliceLength int) bool { + return offset+length < offset || offset+length > sliceLength +} + +// parseField is the main parsing function. Given a byte array and an offset +// into the array, it will try to parse a suitable ASN.1 value out and store it +// in the given Value. 
+func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParameters) (offset int, err os.Error) { + offset = initOffset + fieldType := v.Type() + + // If we have run out of data, it may be that there are optional elements at the end. + if offset == len(bytes) { + if !setDefaultValue(v, params) { + err = SyntaxError{"sequence truncated"} + } + return + } + + // Deal with raw values. + if fieldType == rawValueType { + var t tagAndLength + t, offset, err = parseTagAndLength(bytes, offset) + if err != nil { + return + } + if invalidLength(offset, t.length, len(bytes)) { + err = SyntaxError{"data truncated"} + return + } + result := RawValue{t.class, t.tag, t.isCompound, bytes[offset : offset+t.length], bytes[initOffset : offset+t.length]} + offset += t.length + v.Set(reflect.NewValue(result)) + return + } + + // Deal with the ANY type. + if ifaceType := fieldType; ifaceType.Kind() == reflect.Interface && ifaceType.NumMethod() == 0 { + ifaceValue := v + var t tagAndLength + t, offset, err = parseTagAndLength(bytes, offset) + if err != nil { + return + } + if invalidLength(offset, t.length, len(bytes)) { + err = SyntaxError{"data truncated"} + return + } + var result interface{} + if !t.isCompound && t.class == classUniversal { + innerBytes := bytes[offset : offset+t.length] + switch t.tag { + case tagPrintableString: + result, err = parsePrintableString(innerBytes) + case tagIA5String: + result, err = parseIA5String(innerBytes) + case tagT61String: + result, err = parseT61String(innerBytes) + case tagInteger: + result, err = parseInt64(innerBytes) + case tagBitString: + result, err = parseBitString(innerBytes) + case tagOID: + result, err = parseObjectIdentifier(innerBytes) + case tagUTCTime: + result, err = parseUTCTime(innerBytes) + case tagOctetString: + result = innerBytes + default: + // If we don't know how to handle the type, we just leave Value as nil. + } + } + offset += t.length + if err != nil { + return + } + if result != nil { + ifaceValue.Set(reflect.NewValue(result)) + } + return + } + universalTag, compoundType, ok1 := getUniversalType(fieldType) + if !ok1 { + err = StructuralError{fmt.Sprintf("unknown Go type: %v", fieldType)} + return + } + + t, offset, err := parseTagAndLength(bytes, offset) + if err != nil { + return + } + if params.explicit { + expectedClass := classContextSpecific + if params.application { + expectedClass = classApplication + } + if t.class == expectedClass && t.tag == *params.tag && (t.length == 0 || t.isCompound) { + if t.length > 0 { + t, offset, err = parseTagAndLength(bytes, offset) + if err != nil { + return + } + } else { + if fieldType != flagType { + err = StructuralError{"Zero length explicit tag was not an asn1.Flag"} + return + } + + flagValue := v + flagValue.SetBool(true) + return + } + } else { + // The tags didn't match, it might be an optional element. + ok := setDefaultValue(v, params) + if ok { + offset = initOffset + } else { + err = StructuralError{"explicitly tagged member didn't match"} + } + return + } + } + + // Special case for strings: PrintableString and IA5String both map to + // the Go type string. getUniversalType returns the tag for + // PrintableString when it sees a string so, if we see an IA5String on + // the wire, we change the universal type to match. 
+ if universalTag == tagPrintableString && t.tag == tagIA5String { + universalTag = tagIA5String + } + // Likewise for GeneralString + if universalTag == tagPrintableString && t.tag == tagGeneralString { + universalTag = tagGeneralString + } + + // Special case for time: UTCTime and GeneralizedTime both map to the + // Go type time.Time. + if universalTag == tagUTCTime && t.tag == tagGeneralizedTime { + universalTag = tagGeneralizedTime + } + + expectedClass := classUniversal + expectedTag := universalTag + + if !params.explicit && params.tag != nil { + expectedClass = classContextSpecific + expectedTag = *params.tag + } + + if !params.explicit && params.application && params.tag != nil { + expectedClass = classApplication + expectedTag = *params.tag + } + + // We have unwrapped any explicit tagging at this point. + if t.class != expectedClass || t.tag != expectedTag || t.isCompound != compoundType { + // Tags don't match. Again, it could be an optional element. + ok := setDefaultValue(v, params) + if ok { + offset = initOffset + } else { + err = StructuralError{fmt.Sprintf("tags don't match (%d vs %+v) %+v %s @%d", expectedTag, t, params, fieldType.Name(), offset)} + } + return + } + if invalidLength(offset, t.length, len(bytes)) { + err = SyntaxError{"data truncated"} + return + } + innerBytes := bytes[offset : offset+t.length] + offset += t.length + + // We deal with the structures defined in this package first. + switch fieldType { + case objectIdentifierType: + newSlice, err1 := parseObjectIdentifier(innerBytes) + sliceValue := v + sliceValue.Set(reflect.MakeSlice(sliceValue.Type(), len(newSlice), len(newSlice))) + if err1 == nil { + reflect.Copy(sliceValue, reflect.NewValue(newSlice)) + } + err = err1 + return + case bitStringType: + structValue := v + bs, err1 := parseBitString(innerBytes) + if err1 == nil { + structValue.Set(reflect.NewValue(bs)) + } + err = err1 + return + case timeType: + ptrValue := v + var time *time.Time + var err1 os.Error + if universalTag == tagUTCTime { + time, err1 = parseUTCTime(innerBytes) + } else { + time, err1 = parseGeneralizedTime(innerBytes) + } + if err1 == nil { + ptrValue.Set(reflect.NewValue(time)) + } + err = err1 + return + case enumeratedType: + parsedInt, err1 := parseInt(innerBytes) + enumValue := v + if err1 == nil { + enumValue.SetInt(int64(parsedInt)) + } + err = err1 + return + case flagType: + flagValue := v + flagValue.SetBool(true) + return + } + switch val := v; val.Kind() { + case reflect.Bool: + parsedBool, err1 := parseBool(innerBytes) + if err1 == nil { + val.SetBool(parsedBool) + } + err = err1 + return + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + switch val.Type().Kind() { + case reflect.Int: + parsedInt, err1 := parseInt(innerBytes) + if err1 == nil { + val.SetInt(int64(parsedInt)) + } + err = err1 + return + case reflect.Int64: + parsedInt, err1 := parseInt64(innerBytes) + if err1 == nil { + val.SetInt(parsedInt) + } + err = err1 + return + } + case reflect.Struct: + structType := fieldType + + if structType.NumField() > 0 && + structType.Field(0).Type == rawContentsType { + bytes := bytes[initOffset:offset] + val.Field(0).Set(reflect.NewValue(RawContent(bytes))) + } + + innerOffset := 0 + for i := 0; i < structType.NumField(); i++ { + field := structType.Field(i) + if i == 0 && field.Type == rawContentsType { + continue + } + innerOffset, err = parseField(val.Field(i), innerBytes, innerOffset, parseFieldParameters(field.Tag)) + if err != nil { + return + } + } + // We allow extra bytes 
at the end of the SEQUENCE because + // adding elements to the end has been used in X.509 as the + // version numbers have increased. + return + case reflect.Slice: + sliceType := fieldType + if sliceType.Elem().Kind() == reflect.Uint8 { + val.Set(reflect.MakeSlice(sliceType, len(innerBytes), len(innerBytes))) + reflect.Copy(val, reflect.NewValue(innerBytes)) + return + } + newSlice, err1 := parseSequenceOf(innerBytes, sliceType, sliceType.Elem()) + if err1 == nil { + val.Set(newSlice) + } + err = err1 + return + case reflect.String: + var v string + switch universalTag { + case tagPrintableString: + v, err = parsePrintableString(innerBytes) + case tagIA5String: + v, err = parseIA5String(innerBytes) + case tagT61String: + v, err = parseT61String(innerBytes) + case tagGeneralString: + // GeneralString is specified in ISO-2022/ECMA-35, + // A brief review suggests that it includes structures + // that allow the encoding to change midstring and + // such. We give up and pass it as an 8-bit string. + v, err = parseT61String(innerBytes) + default: + err = SyntaxError{fmt.Sprintf("internal error: unknown string type %d", universalTag)} + } + if err == nil { + val.SetString(v) + } + return + } + err = StructuralError{"unknown Go type"} + return +} + +// setDefaultValue is used to install a default value, from a tag string, into +// a Value. It is successful is the field was optional, even if a default value +// wasn't provided or it failed to install it into the Value. +func setDefaultValue(v reflect.Value, params fieldParameters) (ok bool) { + if !params.optional { + return + } + ok = true + if params.defaultValue == nil { + return + } + switch val := v; val.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + val.SetInt(*params.defaultValue) + } + return +} + +// Unmarshal parses the DER-encoded ASN.1 data structure b +// and uses the reflect package to fill in an arbitrary value pointed at by val. +// Because Unmarshal uses the reflect package, the structs +// being written to must use upper case field names. +// +// An ASN.1 INTEGER can be written to an int or int64. +// If the encoded value does not fit in the Go type, +// Unmarshal returns a parse error. +// +// An ASN.1 BIT STRING can be written to a BitString. +// +// An ASN.1 OCTET STRING can be written to a []byte. +// +// An ASN.1 OBJECT IDENTIFIER can be written to an +// ObjectIdentifier. +// +// An ASN.1 ENUMERATED can be written to an Enumerated. +// +// An ASN.1 UTCTIME or GENERALIZEDTIME can be written to a *time.Time. +// +// An ASN.1 PrintableString or IA5String can be written to a string. +// +// Any of the above ASN.1 values can be written to an interface{}. +// The value stored in the interface has the corresponding Go type. +// For integers, that type is int64. +// +// An ASN.1 SEQUENCE OF x or SET OF x can be written +// to a slice if an x can be written to the slice's element type. +// +// An ASN.1 SEQUENCE or SET can be written to a struct +// if each of the elements in the sequence can be +// written to the corresponding element in the struct. +// +// The following tags on struct fields have special meaning to Unmarshal: +// +// optional marks the field as ASN.1 OPTIONAL +// [explicit] tag:x specifies the ASN.1 tag number; implies ASN.1 CONTEXT SPECIFIC +// default:x sets the default value for optional integer fields +// +// If the type of the first field of a structure is RawContent then the raw +// ASN1 contents of the struct will be stored in it. 
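A minimal usage sketch for the tag strings listed above. The certInfo type, its fields, and the helper are hypothetical; the sketch assumes it sits in the same package as Unmarshal and writes field tags in the whole-string form this version of the package reads:

	// certInfo models SEQUENCE { version [0] EXPLICIT INTEGER DEFAULT 1, serial INTEGER }.
	type certInfo struct {
		Version int "optional,explicit,default:1,tag:0"
		Serial  int64
	}

	func parseCertInfo(der []byte) (*certInfo, os.Error) {
		ci := new(certInfo)
		if _, err := Unmarshal(der, ci); err != nil {
			return nil, err
		}
		return ci, nil
	}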
+// +// Other ASN.1 types are not supported; if it encounters them, +// Unmarshal returns a parse error. +func Unmarshal(b []byte, val interface{}) (rest []byte, err os.Error) { + return UnmarshalWithParams(b, val, "") +} + +// UnmarshalWithParams allows field parameters to be specified for the +// top-level element. The form of the params is the same as the field tags. +func UnmarshalWithParams(b []byte, val interface{}, params string) (rest []byte, err os.Error) { + v := reflect.NewValue(val).Elem() + offset, err := parseField(v, b, 0, parseFieldParameters(params)) + if err != nil { + return nil, err + } + return b[offset:], nil +} diff --git a/src/cmd/gofix/testdata/reflect.datafmt.go.in b/src/cmd/gofix/testdata/reflect.datafmt.go.in new file mode 100644 index 000000000..46c412342 --- /dev/null +++ b/src/cmd/gofix/testdata/reflect.datafmt.go.in @@ -0,0 +1,731 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +/* The datafmt package implements syntax-directed, type-driven formatting + of arbitrary data structures. Formatting a data structure consists of + two phases: first, a parser reads a format specification and builds a + "compiled" format. Then, the format can be applied repeatedly to + arbitrary values. Applying a format to a value evaluates to a []byte + containing the formatted value bytes, or nil. + + A format specification is a set of package declarations and format rules: + + Format = [ Entry { ";" Entry } [ ";" ] ] . + Entry = PackageDecl | FormatRule . + + (The syntax of a format specification is presented in the same EBNF + notation as used in the Go language specification. The syntax of white + space, comments, identifiers, and string literals is the same as in Go.) + + A package declaration binds a package name (such as 'ast') to a + package import path (such as '"go/ast"'). Each package used (in + a type name, see below) must be declared once before use. + + PackageDecl = PackageName ImportPath . + PackageName = identifier . + ImportPath = string . + + A format rule binds a rule name to a format expression. A rule name + may be a type name or one of the special names 'default' or '/'. + A type name may be the name of a predeclared type (for example, 'int', + 'float32', etc.), the package-qualified name of a user-defined type + (for example, 'ast.MapType'), or an identifier indicating the structure + of unnamed composite types ('array', 'chan', 'func', 'interface', 'map', + or 'ptr'). Each rule must have a unique name; rules can be declared in + any order. + + FormatRule = RuleName "=" Expression . + RuleName = TypeName | "default" | "/" . + TypeName = [ PackageName "." ] identifier . + + To format a value, the value's type name is used to select the format rule + (there is an override mechanism, see below). The format expression of the + selected rule specifies how the value is formatted. Each format expression, + when applied to a value, evaluates to a byte sequence or nil. + + In its most general form, a format expression is a list of alternatives, + each of which is a sequence of operands: + + Expression = [ Sequence ] { "|" [ Sequence ] } . + Sequence = Operand { Operand } . + + The formatted result produced by an expression is the result of the first + alternative sequence that evaluates to a non-nil result; if there is no + such alternative, the expression evaluates to nil. 
The result produced by + an operand sequence is the concatenation of the results of its operands. + If any operand in the sequence evaluates to nil, the entire sequence + evaluates to nil. + + There are five kinds of operands: + + Operand = Literal | Field | Group | Option | Repetition . + + Literals evaluate to themselves, with two substitutions. First, + %-formats expand in the manner of fmt.Printf, with the current value + passed as the parameter. Second, the current indentation (see below) + is inserted after every newline or form feed character. + + Literal = string . + + This table shows string literals applied to the value 42 and the + corresponding formatted result: + + "foo" foo + "%x" 2a + "x = %d" x = 42 + "%#x = %d" 0x2a = 42 + + A field operand is a field name optionally followed by an alternate + rule name. The field name may be an identifier or one of the special + names @ or *. + + Field = FieldName [ ":" RuleName ] . + FieldName = identifier | "@" | "*" . + + If the field name is an identifier, the current value must be a struct, + and there must be a field with that name in the struct. The same lookup + rules apply as in the Go language (for instance, the name of an anonymous + field is the unqualified type name). The field name denotes the field + value in the struct. If the field is not found, formatting is aborted + and an error message is returned. (TODO consider changing the semantics + such that if a field is not found, it evaluates to nil). + + The special name '@' denotes the current value. + + The meaning of the special name '*' depends on the type of the current + value: + + array, slice types array, slice element (inside {} only, see below) + interfaces value stored in interface + pointers value pointed to by pointer + + (Implementation restriction: channel, function and map types are not + supported due to missing reflection support). + + Fields are evaluated as follows: If the field value is nil, or an array + or slice element does not exist, the result is nil (see below for details + on array/slice elements). If the value is not nil the field value is + formatted (recursively) using the rule corresponding to its type name, + or the alternate rule name, if given. + + The following example shows a complete format specification for a + struct 'myPackage.Point'. Assume the package + + package myPackage // in directory myDir/myPackage + type Point struct { + name string; + x, y int; + } + + Applying the format specification + + myPackage "myDir/myPackage"; + int = "%d"; + hexInt = "0x%x"; + string = "---%s---"; + myPackage.Point = name "{" x ", " y:hexInt "}"; + + to the value myPackage.Point{"foo", 3, 15} results in + + ---foo---{3, 0xf} + + Finally, an operand may be a grouped, optional, or repeated expression. + A grouped expression ("group") groups a more complex expression (body) + so that it can be used in place of a single operand: + + Group = "(" [ Indentation ">>" ] Body ")" . + Indentation = Expression . + Body = Expression . + + A group body may be prefixed by an indentation expression followed by '>>'. + The indentation expression is applied to the current value like any other + expression and the result, if not nil, is appended to the current indentation + during the evaluation of the body (see also formatting state, below). + + An optional expression ("option") is enclosed in '[]' brackets. + + Option = "[" Body "]" . + + An option evaluates to its body, except that if the body evaluates to nil, + the option expression evaluates to an empty []byte. 
Thus an option's purpose + is to protect the expression containing the option from a nil operand. + + A repeated expression ("repetition") is enclosed in '{}' braces. + + Repetition = "{" Body [ "/" Separator ] "}" . + Separator = Expression . + + A repeated expression is evaluated as follows: The body is evaluated + repeatedly and its results are concatenated until the body evaluates + to nil. The result of the repetition is the (possibly empty) concatenation, + but it is never nil. An implicit index is supplied for the evaluation of + the body: that index is used to address elements of arrays or slices. If + the corresponding elements do not exist, the field denoting the element + evaluates to nil (which in turn may terminate the repetition). + + The body of a repetition may be followed by a '/' and a "separator" + expression. If the separator is present, it is invoked between repetitions + of the body. + + The following example shows a complete format specification for formatting + a slice of unnamed type. Applying the specification + + int = "%b"; + array = { * / ", " }; // array is the type name for an unnamed slice + + to the value '[]int{2, 3, 5, 7}' results in + + 10, 11, 101, 111 + + Default rule: If a format rule named 'default' is present, it is used for + formatting a value if no other rule was found. A common default rule is + + default = "%v" + + to provide default formatting for basic types without having to specify + a specific rule for each basic type. + + Global separator rule: If a format rule named '/' is present, it is + invoked with the current value between literals. If the separator + expression evaluates to nil, it is ignored. + + For instance, a global separator rule may be used to punctuate a sequence + of values with commas. The rules: + + default = "%v"; + / = ", "; + + will format an argument list by printing each one in its default format, + separated by a comma and a space. +*/ +package datafmt + +import ( + "bytes" + "fmt" + "go/token" + "io" + "os" + "reflect" + "runtime" +) + + +// ---------------------------------------------------------------------------- +// Format representation + +// Custom formatters implement the Formatter function type. +// A formatter is invoked with the current formatting state, the +// value to format, and the rule name under which the formatter +// was installed (the same formatter function may be installed +// under different names). The formatter may access the current state +// to guide formatting and use State.Write to append to the state's +// output. +// +// A formatter must return a boolean value indicating if it evaluated +// to a non-nil value (true), or a nil value (false). +// +type Formatter func(state *State, value interface{}, ruleName string) bool + + +// A FormatterMap is a set of custom formatters. +// It maps a rule name to a formatter function. +// +type FormatterMap map[string]Formatter + + +// A parsed format expression is built from the following nodes. 
+// +type ( + expr interface{} + + alternatives []expr // x | y | z + + sequence []expr // x y z + + literal [][]byte // a list of string segments, possibly starting with '%' + + field struct { + fieldName string // including "@", "*" + ruleName string // "" if no rule name specified + } + + group struct { + indent, body expr // (indent >> body) + } + + option struct { + body expr // [body] + } + + repetition struct { + body, separator expr // {body / separator} + } + + custom struct { + ruleName string + fun Formatter + } +) + + +// A Format is the result of parsing a format specification. +// The format may be applied repeatedly to format values. +// +type Format map[string]expr + + +// ---------------------------------------------------------------------------- +// Formatting + +// An application-specific environment may be provided to Format.Apply; +// the environment is available inside custom formatters via State.Env(). +// Environments must implement copying; the Copy method must return an +// complete copy of the receiver. This is necessary so that the formatter +// can save and restore an environment (in case of an absent expression). +// +// If the Environment doesn't change during formatting (this is under +// control of the custom formatters), the Copy function can simply return +// the receiver, and thus can be very light-weight. +// +type Environment interface { + Copy() Environment +} + + +// State represents the current formatting state. +// It is provided as argument to custom formatters. +// +type State struct { + fmt Format // format in use + env Environment // user-supplied environment + errors chan os.Error // not chan *Error (errors <- nil would be wrong!) + hasOutput bool // true after the first literal has been written + indent bytes.Buffer // current indentation + output bytes.Buffer // format output + linePos token.Position // position of line beginning (Column == 0) + default_ expr // possibly nil + separator expr // possibly nil +} + + +func newState(fmt Format, env Environment, errors chan os.Error) *State { + s := new(State) + s.fmt = fmt + s.env = env + s.errors = errors + s.linePos = token.Position{Line: 1} + + // if we have a default rule, cache it's expression for fast access + if x, found := fmt["default"]; found { + s.default_ = x + } + + // if we have a global separator rule, cache it's expression for fast access + if x, found := fmt["/"]; found { + s.separator = x + } + + return s +} + + +// Env returns the environment passed to Format.Apply. +func (s *State) Env() interface{} { return s.env } + + +// LinePos returns the position of the current line beginning +// in the state's output buffer. Line numbers start at 1. +// +func (s *State) LinePos() token.Position { return s.linePos } + + +// Pos returns the position of the next byte to be written to the +// output buffer. Line numbers start at 1. +// +func (s *State) Pos() token.Position { + offs := s.output.Len() + return token.Position{Line: s.linePos.Line, Column: offs - s.linePos.Offset, Offset: offs} +} + + +// Write writes data to the output buffer, inserting the indentation +// string after each newline or form feed character. It cannot return an error. 
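A short trace of the indentation behaviour described above, on illustrative buffer contents:

	// With s.indent holding "\t", the call
	//   s.Write([]byte("case a:\nreturn x\n"))
	// appends "case a:\n\treturn x\n\t" to the output: the indent bytes are copied
	// after every '\n' (and '\f'), and s.linePos is moved to the start of each new line.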
+// +func (s *State) Write(data []byte) (int, os.Error) { + n := 0 + i0 := 0 + for i, ch := range data { + if ch == '\n' || ch == '\f' { + // write text segment and indentation + n1, _ := s.output.Write(data[i0 : i+1]) + n2, _ := s.output.Write(s.indent.Bytes()) + n += n1 + n2 + i0 = i + 1 + s.linePos.Offset = s.output.Len() + s.linePos.Line++ + } + } + n3, _ := s.output.Write(data[i0:]) + return n + n3, nil +} + + +type checkpoint struct { + env Environment + hasOutput bool + outputLen int + linePos token.Position +} + + +func (s *State) save() checkpoint { + saved := checkpoint{nil, s.hasOutput, s.output.Len(), s.linePos} + if s.env != nil { + saved.env = s.env.Copy() + } + return saved +} + + +func (s *State) restore(m checkpoint) { + s.env = m.env + s.output.Truncate(m.outputLen) +} + + +func (s *State) error(msg string) { + s.errors <- os.NewError(msg) + runtime.Goexit() +} + + +// TODO At the moment, unnamed types are simply mapped to the default +// names below. For instance, all unnamed arrays are mapped to +// 'array' which is not really sufficient. Eventually one may want +// to be able to specify rules for say an unnamed slice of T. +// + +func typename(typ reflect.Type) string { + switch typ.(type) { + case *reflect.ArrayType: + return "array" + case *reflect.SliceType: + return "array" + case *reflect.ChanType: + return "chan" + case *reflect.FuncType: + return "func" + case *reflect.InterfaceType: + return "interface" + case *reflect.MapType: + return "map" + case *reflect.PtrType: + return "ptr" + } + return typ.String() +} + +func (s *State) getFormat(name string) expr { + if fexpr, found := s.fmt[name]; found { + return fexpr + } + + if s.default_ != nil { + return s.default_ + } + + s.error(fmt.Sprintf("no format rule for type: '%s'", name)) + return nil +} + + +// eval applies a format expression fexpr to a value. If the expression +// evaluates internally to a non-nil []byte, that slice is appended to +// the state's output buffer and eval returns true. Otherwise, eval +// returns false and the state remains unchanged. 
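The "state remains unchanged" guarantee rests on the checkpoint idiom that the body of eval repeats for each sub-expression:

	//	mark := s.save()              // copy the environment and record the current output length
	//	if !s.eval(x, value, index) { // try one sub-expression
	//		s.restore(mark)       // on failure, put the environment back and truncate the output
	//	}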
+// +func (s *State) eval(fexpr expr, value reflect.Value, index int) bool { + // an empty format expression always evaluates + // to a non-nil (but empty) []byte + if fexpr == nil { + return true + } + + switch t := fexpr.(type) { + case alternatives: + // append the result of the first alternative that evaluates to + // a non-nil []byte to the state's output + mark := s.save() + for _, x := range t { + if s.eval(x, value, index) { + return true + } + s.restore(mark) + } + return false + + case sequence: + // append the result of all operands to the state's output + // unless a nil result is encountered + mark := s.save() + for _, x := range t { + if !s.eval(x, value, index) { + s.restore(mark) + return false + } + } + return true + + case literal: + // write separator, if any + if s.hasOutput { + // not the first literal + if s.separator != nil { + sep := s.separator // save current separator + s.separator = nil // and disable it (avoid recursion) + mark := s.save() + if !s.eval(sep, value, index) { + s.restore(mark) + } + s.separator = sep // enable it again + } + } + s.hasOutput = true + // write literal segments + for _, lit := range t { + if len(lit) > 1 && lit[0] == '%' { + // segment contains a %-format at the beginning + if lit[1] == '%' { + // "%%" is printed as a single "%" + s.Write(lit[1:]) + } else { + // use s instead of s.output to get indentation right + fmt.Fprintf(s, string(lit), value.Interface()) + } + } else { + // segment contains no %-formats + s.Write(lit) + } + } + return true // a literal never evaluates to nil + + case *field: + // determine field value + switch t.fieldName { + case "@": + // field value is current value + + case "*": + // indirection: operation is type-specific + switch v := value.(type) { + case *reflect.ArrayValue: + if v.Len() <= index { + return false + } + value = v.Elem(index) + + case *reflect.SliceValue: + if v.IsNil() || v.Len() <= index { + return false + } + value = v.Elem(index) + + case *reflect.MapValue: + s.error("reflection support for maps incomplete") + + case *reflect.PtrValue: + if v.IsNil() { + return false + } + value = v.Elem() + + case *reflect.InterfaceValue: + if v.IsNil() { + return false + } + value = v.Elem() + + case *reflect.ChanValue: + s.error("reflection support for chans incomplete") + + case *reflect.FuncValue: + s.error("reflection support for funcs incomplete") + + default: + s.error(fmt.Sprintf("error: * does not apply to `%s`", value.Type())) + } + + default: + // value is value of named field + var field reflect.Value + if sval, ok := value.(*reflect.StructValue); ok { + field = sval.FieldByName(t.fieldName) + if field == nil { + // TODO consider just returning false in this case + s.error(fmt.Sprintf("error: no field `%s` in `%s`", t.fieldName, value.Type())) + } + } + value = field + } + + // determine rule + ruleName := t.ruleName + if ruleName == "" { + // no alternate rule name, value type determines rule + ruleName = typename(value.Type()) + } + fexpr = s.getFormat(ruleName) + + mark := s.save() + if !s.eval(fexpr, value, index) { + s.restore(mark) + return false + } + return true + + case *group: + // remember current indentation + indentLen := s.indent.Len() + + // update current indentation + mark := s.save() + s.eval(t.indent, value, index) + // if the indentation evaluates to nil, the state's output buffer + // didn't change - either way it's ok to append the difference to + // the current identation + s.indent.Write(s.output.Bytes()[mark.outputLen:s.output.Len()]) + s.restore(mark) + + // 
format group body + mark = s.save() + b := true + if !s.eval(t.body, value, index) { + s.restore(mark) + b = false + } + + // reset indentation + s.indent.Truncate(indentLen) + return b + + case *option: + // evaluate the body and append the result to the state's output + // buffer unless the result is nil + mark := s.save() + if !s.eval(t.body, value, 0) { // TODO is 0 index correct? + s.restore(mark) + } + return true // an option never evaluates to nil + + case *repetition: + // evaluate the body and append the result to the state's output + // buffer until a result is nil + for i := 0; ; i++ { + mark := s.save() + // write separator, if any + if i > 0 && t.separator != nil { + // nil result from separator is ignored + mark := s.save() + if !s.eval(t.separator, value, i) { + s.restore(mark) + } + } + if !s.eval(t.body, value, i) { + s.restore(mark) + break + } + } + return true // a repetition never evaluates to nil + + case *custom: + // invoke the custom formatter to obtain the result + mark := s.save() + if !t.fun(s, value.Interface(), t.ruleName) { + s.restore(mark) + return false + } + return true + } + + panic("unreachable") + return false +} + + +// Eval formats each argument according to the format +// f and returns the resulting []byte and os.Error. If +// an error occurred, the []byte contains the partially +// formatted result. An environment env may be passed +// in which is available in custom formatters through +// the state parameter. +// +func (f Format) Eval(env Environment, args ...interface{}) ([]byte, os.Error) { + if f == nil { + return nil, os.NewError("format is nil") + } + + errors := make(chan os.Error) + s := newState(f, env, errors) + + go func() { + for _, v := range args { + fld := reflect.NewValue(v) + if fld == nil { + errors <- os.NewError("nil argument") + return + } + mark := s.save() + if !s.eval(s.getFormat(typename(fld.Type())), fld, 0) { // TODO is 0 index correct? + s.restore(mark) + } + } + errors <- nil // no errors + }() + + err := <-errors + return s.output.Bytes(), err +} + + +// ---------------------------------------------------------------------------- +// Convenience functions + +// Fprint formats each argument according to the format f +// and writes to w. The result is the total number of bytes +// written and an os.Error, if any. +// +func (f Format) Fprint(w io.Writer, env Environment, args ...interface{}) (int, os.Error) { + data, err := f.Eval(env, args...) + if err != nil { + // TODO should we print partial result in case of error? + return 0, err + } + return w.Write(data) +} + + +// Print formats each argument according to the format f +// and writes to standard output. The result is the total +// number of bytes written and an os.Error, if any. +// +func (f Format) Print(args ...interface{}) (int, os.Error) { + return f.Fprint(os.Stdout, nil, args...) +} + + +// Sprint formats each argument according to the format f +// and returns the resulting string. If an error occurs +// during formatting, the result string contains the +// partially formatted result followed by an error message. +// +func (f Format) Sprint(args ...interface{}) string { + var buf bytes.Buffer + _, err := f.Fprint(&buf, nil, args...) 
+ if err != nil { + var i interface{} = args + fmt.Fprintf(&buf, "--- Sprint(%s) failed: %v", fmt.Sprint(i), err) + } + return buf.String() +} diff --git a/src/cmd/gofix/testdata/reflect.datafmt.go.out b/src/cmd/gofix/testdata/reflect.datafmt.go.out new file mode 100644 index 000000000..6d816fc2d --- /dev/null +++ b/src/cmd/gofix/testdata/reflect.datafmt.go.out @@ -0,0 +1,731 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +/* The datafmt package implements syntax-directed, type-driven formatting + of arbitrary data structures. Formatting a data structure consists of + two phases: first, a parser reads a format specification and builds a + "compiled" format. Then, the format can be applied repeatedly to + arbitrary values. Applying a format to a value evaluates to a []byte + containing the formatted value bytes, or nil. + + A format specification is a set of package declarations and format rules: + + Format = [ Entry { ";" Entry } [ ";" ] ] . + Entry = PackageDecl | FormatRule . + + (The syntax of a format specification is presented in the same EBNF + notation as used in the Go language specification. The syntax of white + space, comments, identifiers, and string literals is the same as in Go.) + + A package declaration binds a package name (such as 'ast') to a + package import path (such as '"go/ast"'). Each package used (in + a type name, see below) must be declared once before use. + + PackageDecl = PackageName ImportPath . + PackageName = identifier . + ImportPath = string . + + A format rule binds a rule name to a format expression. A rule name + may be a type name or one of the special names 'default' or '/'. + A type name may be the name of a predeclared type (for example, 'int', + 'float32', etc.), the package-qualified name of a user-defined type + (for example, 'ast.MapType'), or an identifier indicating the structure + of unnamed composite types ('array', 'chan', 'func', 'interface', 'map', + or 'ptr'). Each rule must have a unique name; rules can be declared in + any order. + + FormatRule = RuleName "=" Expression . + RuleName = TypeName | "default" | "/" . + TypeName = [ PackageName "." ] identifier . + + To format a value, the value's type name is used to select the format rule + (there is an override mechanism, see below). The format expression of the + selected rule specifies how the value is formatted. Each format expression, + when applied to a value, evaluates to a byte sequence or nil. + + In its most general form, a format expression is a list of alternatives, + each of which is a sequence of operands: + + Expression = [ Sequence ] { "|" [ Sequence ] } . + Sequence = Operand { Operand } . + + The formatted result produced by an expression is the result of the first + alternative sequence that evaluates to a non-nil result; if there is no + such alternative, the expression evaluates to nil. The result produced by + an operand sequence is the concatenation of the results of its operands. + If any operand in the sequence evaluates to nil, the entire sequence + evaluates to nil. + + There are five kinds of operands: + + Operand = Literal | Field | Group | Option | Repetition . + + Literals evaluate to themselves, with two substitutions. First, + %-formats expand in the manner of fmt.Printf, with the current value + passed as the parameter. Second, the current indentation (see below) + is inserted after every newline or form feed character. 
+ + Literal = string . + + This table shows string literals applied to the value 42 and the + corresponding formatted result: + + "foo" foo + "%x" 2a + "x = %d" x = 42 + "%#x = %d" 0x2a = 42 + + A field operand is a field name optionally followed by an alternate + rule name. The field name may be an identifier or one of the special + names @ or *. + + Field = FieldName [ ":" RuleName ] . + FieldName = identifier | "@" | "*" . + + If the field name is an identifier, the current value must be a struct, + and there must be a field with that name in the struct. The same lookup + rules apply as in the Go language (for instance, the name of an anonymous + field is the unqualified type name). The field name denotes the field + value in the struct. If the field is not found, formatting is aborted + and an error message is returned. (TODO consider changing the semantics + such that if a field is not found, it evaluates to nil). + + The special name '@' denotes the current value. + + The meaning of the special name '*' depends on the type of the current + value: + + array, slice types array, slice element (inside {} only, see below) + interfaces value stored in interface + pointers value pointed to by pointer + + (Implementation restriction: channel, function and map types are not + supported due to missing reflection support). + + Fields are evaluated as follows: If the field value is nil, or an array + or slice element does not exist, the result is nil (see below for details + on array/slice elements). If the value is not nil the field value is + formatted (recursively) using the rule corresponding to its type name, + or the alternate rule name, if given. + + The following example shows a complete format specification for a + struct 'myPackage.Point'. Assume the package + + package myPackage // in directory myDir/myPackage + type Point struct { + name string; + x, y int; + } + + Applying the format specification + + myPackage "myDir/myPackage"; + int = "%d"; + hexInt = "0x%x"; + string = "---%s---"; + myPackage.Point = name "{" x ", " y:hexInt "}"; + + to the value myPackage.Point{"foo", 3, 15} results in + + ---foo---{3, 0xf} + + Finally, an operand may be a grouped, optional, or repeated expression. + A grouped expression ("group") groups a more complex expression (body) + so that it can be used in place of a single operand: + + Group = "(" [ Indentation ">>" ] Body ")" . + Indentation = Expression . + Body = Expression . + + A group body may be prefixed by an indentation expression followed by '>>'. + The indentation expression is applied to the current value like any other + expression and the result, if not nil, is appended to the current indentation + during the evaluation of the body (see also formatting state, below). + + An optional expression ("option") is enclosed in '[]' brackets. + + Option = "[" Body "]" . + + An option evaluates to its body, except that if the body evaluates to nil, + the option expression evaluates to an empty []byte. Thus an option's purpose + is to protect the expression containing the option from a nil operand. + + A repeated expression ("repetition") is enclosed in '{}' braces. + + Repetition = "{" Body [ "/" Separator ] "}" . + Separator = Expression . + + A repeated expression is evaluated as follows: The body is evaluated + repeatedly and its results are concatenated until the body evaluates + to nil. The result of the repetition is the (possibly empty) concatenation, + but it is never nil. 
An implicit index is supplied for the evaluation of + the body: that index is used to address elements of arrays or slices. If + the corresponding elements do not exist, the field denoting the element + evaluates to nil (which in turn may terminate the repetition). + + The body of a repetition may be followed by a '/' and a "separator" + expression. If the separator is present, it is invoked between repetitions + of the body. + + The following example shows a complete format specification for formatting + a slice of unnamed type. Applying the specification + + int = "%b"; + array = { * / ", " }; // array is the type name for an unnamed slice + + to the value '[]int{2, 3, 5, 7}' results in + + 10, 11, 101, 111 + + Default rule: If a format rule named 'default' is present, it is used for + formatting a value if no other rule was found. A common default rule is + + default = "%v" + + to provide default formatting for basic types without having to specify + a specific rule for each basic type. + + Global separator rule: If a format rule named '/' is present, it is + invoked with the current value between literals. If the separator + expression evaluates to nil, it is ignored. + + For instance, a global separator rule may be used to punctuate a sequence + of values with commas. The rules: + + default = "%v"; + / = ", "; + + will format an argument list by printing each one in its default format, + separated by a comma and a space. +*/ +package datafmt + +import ( + "bytes" + "fmt" + "go/token" + "io" + "os" + "reflect" + "runtime" +) + + +// ---------------------------------------------------------------------------- +// Format representation + +// Custom formatters implement the Formatter function type. +// A formatter is invoked with the current formatting state, the +// value to format, and the rule name under which the formatter +// was installed (the same formatter function may be installed +// under different names). The formatter may access the current state +// to guide formatting and use State.Write to append to the state's +// output. +// +// A formatter must return a boolean value indicating if it evaluated +// to a non-nil value (true), or a nil value (false). +// +type Formatter func(state *State, value interface{}, ruleName string) bool + + +// A FormatterMap is a set of custom formatters. +// It maps a rule name to a formatter function. +// +type FormatterMap map[string]Formatter + + +// A parsed format expression is built from the following nodes. +// +type ( + expr interface{} + + alternatives []expr // x | y | z + + sequence []expr // x y z + + literal [][]byte // a list of string segments, possibly starting with '%' + + field struct { + fieldName string // including "@", "*" + ruleName string // "" if no rule name specified + } + + group struct { + indent, body expr // (indent >> body) + } + + option struct { + body expr // [body] + } + + repetition struct { + body, separator expr // {body / separator} + } + + custom struct { + ruleName string + fun Formatter + } +) + + +// A Format is the result of parsing a format specification. +// The format may be applied repeatedly to format values. +// +type Format map[string]expr + + +// ---------------------------------------------------------------------------- +// Formatting + +// An application-specific environment may be provided to Format.Apply; +// the environment is available inside custom formatters via State.Env(). +// Environments must implement copying; the Copy method must return an +// complete copy of the receiver. 
This is necessary so that the formatter +// can save and restore an environment (in case of an absent expression). +// +// If the Environment doesn't change during formatting (this is under +// control of the custom formatters), the Copy function can simply return +// the receiver, and thus can be very light-weight. +// +type Environment interface { + Copy() Environment +} + + +// State represents the current formatting state. +// It is provided as argument to custom formatters. +// +type State struct { + fmt Format // format in use + env Environment // user-supplied environment + errors chan os.Error // not chan *Error (errors <- nil would be wrong!) + hasOutput bool // true after the first literal has been written + indent bytes.Buffer // current indentation + output bytes.Buffer // format output + linePos token.Position // position of line beginning (Column == 0) + default_ expr // possibly nil + separator expr // possibly nil +} + + +func newState(fmt Format, env Environment, errors chan os.Error) *State { + s := new(State) + s.fmt = fmt + s.env = env + s.errors = errors + s.linePos = token.Position{Line: 1} + + // if we have a default rule, cache it's expression for fast access + if x, found := fmt["default"]; found { + s.default_ = x + } + + // if we have a global separator rule, cache it's expression for fast access + if x, found := fmt["/"]; found { + s.separator = x + } + + return s +} + + +// Env returns the environment passed to Format.Apply. +func (s *State) Env() interface{} { return s.env } + + +// LinePos returns the position of the current line beginning +// in the state's output buffer. Line numbers start at 1. +// +func (s *State) LinePos() token.Position { return s.linePos } + + +// Pos returns the position of the next byte to be written to the +// output buffer. Line numbers start at 1. +// +func (s *State) Pos() token.Position { + offs := s.output.Len() + return token.Position{Line: s.linePos.Line, Column: offs - s.linePos.Offset, Offset: offs} +} + + +// Write writes data to the output buffer, inserting the indentation +// string after each newline or form feed character. It cannot return an error. +// +func (s *State) Write(data []byte) (int, os.Error) { + n := 0 + i0 := 0 + for i, ch := range data { + if ch == '\n' || ch == '\f' { + // write text segment and indentation + n1, _ := s.output.Write(data[i0 : i+1]) + n2, _ := s.output.Write(s.indent.Bytes()) + n += n1 + n2 + i0 = i + 1 + s.linePos.Offset = s.output.Len() + s.linePos.Line++ + } + } + n3, _ := s.output.Write(data[i0:]) + return n + n3, nil +} + + +type checkpoint struct { + env Environment + hasOutput bool + outputLen int + linePos token.Position +} + + +func (s *State) save() checkpoint { + saved := checkpoint{nil, s.hasOutput, s.output.Len(), s.linePos} + if s.env != nil { + saved.env = s.env.Copy() + } + return saved +} + + +func (s *State) restore(m checkpoint) { + s.env = m.env + s.output.Truncate(m.outputLen) +} + + +func (s *State) error(msg string) { + s.errors <- os.NewError(msg) + runtime.Goexit() +} + + +// TODO At the moment, unnamed types are simply mapped to the default +// names below. For instance, all unnamed arrays are mapped to +// 'array' which is not really sufficient. Eventually one may want +// to be able to specify rules for say an unnamed slice of T. 
+// + +func typename(typ reflect.Type) string { + switch typ.Kind() { + case reflect.Array: + return "array" + case reflect.Slice: + return "array" + case reflect.Chan: + return "chan" + case reflect.Func: + return "func" + case reflect.Interface: + return "interface" + case reflect.Map: + return "map" + case reflect.Ptr: + return "ptr" + } + return typ.String() +} + +func (s *State) getFormat(name string) expr { + if fexpr, found := s.fmt[name]; found { + return fexpr + } + + if s.default_ != nil { + return s.default_ + } + + s.error(fmt.Sprintf("no format rule for type: '%s'", name)) + return nil +} + + +// eval applies a format expression fexpr to a value. If the expression +// evaluates internally to a non-nil []byte, that slice is appended to +// the state's output buffer and eval returns true. Otherwise, eval +// returns false and the state remains unchanged. +// +func (s *State) eval(fexpr expr, value reflect.Value, index int) bool { + // an empty format expression always evaluates + // to a non-nil (but empty) []byte + if fexpr == nil { + return true + } + + switch t := fexpr.(type) { + case alternatives: + // append the result of the first alternative that evaluates to + // a non-nil []byte to the state's output + mark := s.save() + for _, x := range t { + if s.eval(x, value, index) { + return true + } + s.restore(mark) + } + return false + + case sequence: + // append the result of all operands to the state's output + // unless a nil result is encountered + mark := s.save() + for _, x := range t { + if !s.eval(x, value, index) { + s.restore(mark) + return false + } + } + return true + + case literal: + // write separator, if any + if s.hasOutput { + // not the first literal + if s.separator != nil { + sep := s.separator // save current separator + s.separator = nil // and disable it (avoid recursion) + mark := s.save() + if !s.eval(sep, value, index) { + s.restore(mark) + } + s.separator = sep // enable it again + } + } + s.hasOutput = true + // write literal segments + for _, lit := range t { + if len(lit) > 1 && lit[0] == '%' { + // segment contains a %-format at the beginning + if lit[1] == '%' { + // "%%" is printed as a single "%" + s.Write(lit[1:]) + } else { + // use s instead of s.output to get indentation right + fmt.Fprintf(s, string(lit), value.Interface()) + } + } else { + // segment contains no %-formats + s.Write(lit) + } + } + return true // a literal never evaluates to nil + + case *field: + // determine field value + switch t.fieldName { + case "@": + // field value is current value + + case "*": + // indirection: operation is type-specific + switch v := value; v.Kind() { + case reflect.Array: + if v.Len() <= index { + return false + } + value = v.Index(index) + + case reflect.Slice: + if v.IsNil() || v.Len() <= index { + return false + } + value = v.Index(index) + + case reflect.Map: + s.error("reflection support for maps incomplete") + + case reflect.Ptr: + if v.IsNil() { + return false + } + value = v.Elem() + + case reflect.Interface: + if v.IsNil() { + return false + } + value = v.Elem() + + case reflect.Chan: + s.error("reflection support for chans incomplete") + + case reflect.Func: + s.error("reflection support for funcs incomplete") + + default: + s.error(fmt.Sprintf("error: * does not apply to `%s`", value.Type())) + } + + default: + // value is value of named field + var field reflect.Value + if sval := value; sval.Kind() == reflect.Struct { + field = sval.FieldByName(t.fieldName) + if !field.IsValid() { + // TODO consider just returning false in this 
case + s.error(fmt.Sprintf("error: no field `%s` in `%s`", t.fieldName, value.Type())) + } + } + value = field + } + + // determine rule + ruleName := t.ruleName + if ruleName == "" { + // no alternate rule name, value type determines rule + ruleName = typename(value.Type()) + } + fexpr = s.getFormat(ruleName) + + mark := s.save() + if !s.eval(fexpr, value, index) { + s.restore(mark) + return false + } + return true + + case *group: + // remember current indentation + indentLen := s.indent.Len() + + // update current indentation + mark := s.save() + s.eval(t.indent, value, index) + // if the indentation evaluates to nil, the state's output buffer + // didn't change - either way it's ok to append the difference to + // the current identation + s.indent.Write(s.output.Bytes()[mark.outputLen:s.output.Len()]) + s.restore(mark) + + // format group body + mark = s.save() + b := true + if !s.eval(t.body, value, index) { + s.restore(mark) + b = false + } + + // reset indentation + s.indent.Truncate(indentLen) + return b + + case *option: + // evaluate the body and append the result to the state's output + // buffer unless the result is nil + mark := s.save() + if !s.eval(t.body, value, 0) { // TODO is 0 index correct? + s.restore(mark) + } + return true // an option never evaluates to nil + + case *repetition: + // evaluate the body and append the result to the state's output + // buffer until a result is nil + for i := 0; ; i++ { + mark := s.save() + // write separator, if any + if i > 0 && t.separator != nil { + // nil result from separator is ignored + mark := s.save() + if !s.eval(t.separator, value, i) { + s.restore(mark) + } + } + if !s.eval(t.body, value, i) { + s.restore(mark) + break + } + } + return true // a repetition never evaluates to nil + + case *custom: + // invoke the custom formatter to obtain the result + mark := s.save() + if !t.fun(s, value.Interface(), t.ruleName) { + s.restore(mark) + return false + } + return true + } + + panic("unreachable") + return false +} + + +// Eval formats each argument according to the format +// f and returns the resulting []byte and os.Error. If +// an error occurred, the []byte contains the partially +// formatted result. An environment env may be passed +// in which is available in custom formatters through +// the state parameter. +// +func (f Format) Eval(env Environment, args ...interface{}) ([]byte, os.Error) { + if f == nil { + return nil, os.NewError("format is nil") + } + + errors := make(chan os.Error) + s := newState(f, env, errors) + + go func() { + for _, v := range args { + fld := reflect.NewValue(v) + if !fld.IsValid() { + errors <- os.NewError("nil argument") + return + } + mark := s.save() + if !s.eval(s.getFormat(typename(fld.Type())), fld, 0) { // TODO is 0 index correct? + s.restore(mark) + } + } + errors <- nil // no errors + }() + + err := <-errors + return s.output.Bytes(), err +} + + +// ---------------------------------------------------------------------------- +// Convenience functions + +// Fprint formats each argument according to the format f +// and writes to w. The result is the total number of bytes +// written and an os.Error, if any. +// +func (f Format) Fprint(w io.Writer, env Environment, args ...interface{}) (int, os.Error) { + data, err := f.Eval(env, args...) + if err != nil { + // TODO should we print partial result in case of error? + return 0, err + } + return w.Write(data) +} + + +// Print formats each argument according to the format f +// and writes to standard output. 
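The Format machinery above is exercised through Eval and its wrappers. A minimal usage sketch, not part of the patch: it assumes f is an already parsed Format whose rules cover the type of x, and it passes a nil Environment exactly as Print and Sprint do.

data, err := f.Eval(nil, x) // nil Environment, as Print and Sprint pass
if err != nil {
	// data still holds the partially formatted result
}
n, err := f.Fprint(os.Stdout, nil, x) // write the formatted bytes to stdout
s := f.Sprint(x)                      // on error, an error message is appended
fmt.Println(len(data), n, s)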
The result is the total +// number of bytes written and an os.Error, if any. +// +func (f Format) Print(args ...interface{}) (int, os.Error) { + return f.Fprint(os.Stdout, nil, args...) +} + + +// Sprint formats each argument according to the format f +// and returns the resulting string. If an error occurs +// during formatting, the result string contains the +// partially formatted result followed by an error message. +// +func (f Format) Sprint(args ...interface{}) string { + var buf bytes.Buffer + _, err := f.Fprint(&buf, nil, args...) + if err != nil { + var i interface{} = args + fmt.Fprintf(&buf, "--- Sprint(%s) failed: %v", fmt.Sprint(i), err) + } + return buf.String() +} diff --git a/src/cmd/gofix/testdata/reflect.decode.go.in b/src/cmd/gofix/testdata/reflect.decode.go.in new file mode 100644 index 000000000..501230c0c --- /dev/null +++ b/src/cmd/gofix/testdata/reflect.decode.go.in @@ -0,0 +1,907 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Represents JSON data structure using native Go types: booleans, floats, +// strings, arrays, and maps. + +package json + +import ( + "container/vector" + "encoding/base64" + "os" + "reflect" + "runtime" + "strconv" + "strings" + "unicode" + "utf16" + "utf8" +) + +// Unmarshal parses the JSON-encoded data and stores the result +// in the value pointed to by v. +// +// Unmarshal traverses the value v recursively. +// If an encountered value implements the Unmarshaler interface, +// Unmarshal calls its UnmarshalJSON method with a well-formed +// JSON encoding. +// +// Otherwise, Unmarshal uses the inverse of the encodings that +// Marshal uses, allocating maps, slices, and pointers as necessary, +// with the following additional rules: +// +// To unmarshal a JSON value into a nil interface value, the +// type stored in the interface value is one of: +// +// bool, for JSON booleans +// float64, for JSON numbers +// string, for JSON strings +// []interface{}, for JSON arrays +// map[string]interface{}, for JSON objects +// nil for JSON null +// +// If a JSON value is not appropriate for a given target type, +// or if a JSON number overflows the target type, Unmarshal +// skips that field and completes the unmarshalling as best it can. +// If no more serious errors are encountered, Unmarshal returns +// an UnmarshalTypeError describing the earliest such error. +// +func Unmarshal(data []byte, v interface{}) os.Error { + d := new(decodeState).init(data) + + // Quick check for well-formedness. + // Avoids filling out half a data structure + // before discovering a JSON syntax error. + err := checkValid(data, &d.scan) + if err != nil { + return err + } + + return d.unmarshal(v) +} + +// Unmarshaler is the interface implemented by objects +// that can unmarshal a JSON description of themselves. +// The input can be assumed to be a valid JSON object +// encoding. UnmarshalJSON must copy the JSON data +// if it wishes to retain the data after returning. +type Unmarshaler interface { + UnmarshalJSON([]byte) os.Error +} + + +// An UnmarshalTypeError describes a JSON value that was +// not appropriate for a value of a specific Go type. 
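The Unmarshal documentation above spells out how JSON maps onto Go values. The sketch below is not part of the patch; the field names and inputs are illustrative, and the import path in this pre-Go 1 tree is "json".

var p struct {
	Name string
	Age  int
}
if err := json.Unmarshal([]byte(`{"Name":"Ann","Age":40}`), &p); err != nil {
	// handle the os.Error
}

var v interface{}
json.Unmarshal([]byte(`[1,"two",null]`), &v)
// v now holds []interface{}{float64(1), "two", nil}: numbers arrive as
// float64, strings as string, and null as nil, as listed above.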
+type UnmarshalTypeError struct { + Value string // description of JSON value - "bool", "array", "number -5" + Type reflect.Type // type of Go value it could not be assigned to +} + +func (e *UnmarshalTypeError) String() string { + return "json: cannot unmarshal " + e.Value + " into Go value of type " + e.Type.String() +} + +// An UnmarshalFieldError describes a JSON object key that +// led to an unexported (and therefore unwritable) struct field. +type UnmarshalFieldError struct { + Key string + Type *reflect.StructType + Field reflect.StructField +} + +func (e *UnmarshalFieldError) String() string { + return "json: cannot unmarshal object key " + strconv.Quote(e.Key) + " into unexported field " + e.Field.Name + " of type " + e.Type.String() +} + +// An InvalidUnmarshalError describes an invalid argument passed to Unmarshal. +// (The argument to Unmarshal must be a non-nil pointer.) +type InvalidUnmarshalError struct { + Type reflect.Type +} + +func (e *InvalidUnmarshalError) String() string { + if e.Type == nil { + return "json: Unmarshal(nil)" + } + + if _, ok := e.Type.(*reflect.PtrType); !ok { + return "json: Unmarshal(non-pointer " + e.Type.String() + ")" + } + return "json: Unmarshal(nil " + e.Type.String() + ")" +} + +func (d *decodeState) unmarshal(v interface{}) (err os.Error) { + defer func() { + if r := recover(); r != nil { + if _, ok := r.(runtime.Error); ok { + panic(r) + } + err = r.(os.Error) + } + }() + + rv := reflect.NewValue(v) + pv, ok := rv.(*reflect.PtrValue) + if !ok || pv.IsNil() { + return &InvalidUnmarshalError{reflect.Typeof(v)} + } + + d.scan.reset() + // We decode rv not pv.Elem because the Unmarshaler interface + // test must be applied at the top level of the value. + d.value(rv) + return d.savedError +} + +// decodeState represents the state while decoding a JSON value. +type decodeState struct { + data []byte + off int // read offset in data + scan scanner + nextscan scanner // for calls to nextValue + savedError os.Error +} + +// errPhase is used for errors that should not happen unless +// there is a bug in the JSON decoder or something is editing +// the data slice while the decoder executes. +var errPhase = os.NewError("JSON decoder out of sync - data changing underfoot?") + +func (d *decodeState) init(data []byte) *decodeState { + d.data = data + d.off = 0 + d.savedError = nil + return d +} + +// error aborts the decoding by panicking with err. +func (d *decodeState) error(err os.Error) { + panic(err) +} + +// saveError saves the first err it is called with, +// for reporting at the end of the unmarshal. +func (d *decodeState) saveError(err os.Error) { + if d.savedError == nil { + d.savedError = err + } +} + +// next cuts off and returns the next full JSON value in d.data[d.off:]. +// The next value is known to be an object or array, not a literal. +func (d *decodeState) next() []byte { + c := d.data[d.off] + item, rest, err := nextValue(d.data[d.off:], &d.nextscan) + if err != nil { + d.error(err) + } + d.off = len(d.data) - len(rest) + + // Our scanner has seen the opening brace/bracket + // and thinks we're still in the middle of the object. + // invent a closing brace/bracket to get it out. + if c == '{' { + d.scan.step(&d.scan, '}') + } else { + d.scan.step(&d.scan, ']') + } + + return item +} + +// scanWhile processes bytes in d.data[d.off:] until it +// receives a scan code not equal to op. +// It updates d.off and returns the new scan code. 
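d.error above aborts a decode by panicking, and the deferred recover in unmarshal (earlier in this file) turns that panic back into the returned os.Error while re-panicking on runtime errors. A condensed sketch of the same discipline, not part of the patch; decode and its input are illustrative, and imports of os and runtime are assumed.

func decode(data []byte) (err os.Error) {
	defer func() {
		if r := recover(); r != nil {
			if _, ok := r.(runtime.Error); ok {
				panic(r) // genuine bugs keep crashing
			}
			err = r.(os.Error) // deliberate aborts become the return value
		}
	}()
	if len(data) == 0 {
		panic(os.NewError("empty input")) // abort from deep in the call tree
	}
	return nil
}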
+func (d *decodeState) scanWhile(op int) int { + var newOp int + for { + if d.off >= len(d.data) { + newOp = d.scan.eof() + d.off = len(d.data) + 1 // mark processed EOF with len+1 + } else { + c := int(d.data[d.off]) + d.off++ + newOp = d.scan.step(&d.scan, c) + } + if newOp != op { + break + } + } + return newOp +} + +// value decodes a JSON value from d.data[d.off:] into the value. +// it updates d.off to point past the decoded value. +func (d *decodeState) value(v reflect.Value) { + if v == nil { + _, rest, err := nextValue(d.data[d.off:], &d.nextscan) + if err != nil { + d.error(err) + } + d.off = len(d.data) - len(rest) + + // d.scan thinks we're still at the beginning of the item. + // Feed in an empty string - the shortest, simplest value - + // so that it knows we got to the end of the value. + if d.scan.step == stateRedo { + panic("redo") + } + d.scan.step(&d.scan, '"') + d.scan.step(&d.scan, '"') + return + } + + switch op := d.scanWhile(scanSkipSpace); op { + default: + d.error(errPhase) + + case scanBeginArray: + d.array(v) + + case scanBeginObject: + d.object(v) + + case scanBeginLiteral: + d.literal(v) + } +} + +// indirect walks down v allocating pointers as needed, +// until it gets to a non-pointer. +// if it encounters an Unmarshaler, indirect stops and returns that. +// if wantptr is true, indirect stops at the last pointer. +func (d *decodeState) indirect(v reflect.Value, wantptr bool) (Unmarshaler, reflect.Value) { + for { + var isUnmarshaler bool + if v.Type().NumMethod() > 0 { + // Remember that this is an unmarshaler, + // but wait to return it until after allocating + // the pointer (if necessary). + _, isUnmarshaler = v.Interface().(Unmarshaler) + } + + if iv, ok := v.(*reflect.InterfaceValue); ok && !iv.IsNil() { + v = iv.Elem() + continue + } + pv, ok := v.(*reflect.PtrValue) + if !ok { + break + } + _, isptrptr := pv.Elem().(*reflect.PtrValue) + if !isptrptr && wantptr && !isUnmarshaler { + return nil, pv + } + if pv.IsNil() { + pv.PointTo(reflect.MakeZero(pv.Type().(*reflect.PtrType).Elem())) + } + if isUnmarshaler { + // Using v.Interface().(Unmarshaler) + // here means that we have to use a pointer + // as the struct field. We cannot use a value inside + // a pointer to a struct, because in that case + // v.Interface() is the value (x.f) not the pointer (&x.f). + // This is an unfortunate consequence of reflect. + // An alternative would be to look up the + // UnmarshalJSON method and return a FuncValue. + return v.Interface().(Unmarshaler), nil + } + v = pv.Elem() + } + return nil, v +} + +// array consumes an array from d.data[d.off-1:], decoding into the value v. +// the first byte of the array ('[') has been read already. +func (d *decodeState) array(v reflect.Value) { + // Check for unmarshaler. + unmarshaler, pv := d.indirect(v, false) + if unmarshaler != nil { + d.off-- + err := unmarshaler.UnmarshalJSON(d.next()) + if err != nil { + d.error(err) + } + return + } + v = pv + + // Decoding into nil interface? Switch to non-reflect code. + iv, ok := v.(*reflect.InterfaceValue) + if ok { + iv.Set(reflect.NewValue(d.arrayInterface())) + return + } + + // Check type of target. + av, ok := v.(reflect.ArrayOrSliceValue) + if !ok { + d.saveError(&UnmarshalTypeError{"array", v.Type()}) + d.off-- + d.next() + return + } + + sv, _ := v.(*reflect.SliceValue) + + i := 0 + for { + // Look ahead for ] - can only happen on first iteration. 
+ op := d.scanWhile(scanSkipSpace) + if op == scanEndArray { + break + } + + // Back up so d.value can have the byte we just read. + d.off-- + d.scan.undo(op) + + // Get element of array, growing if necessary. + if i >= av.Cap() && sv != nil { + newcap := sv.Cap() + sv.Cap()/2 + if newcap < 4 { + newcap = 4 + } + newv := reflect.MakeSlice(sv.Type().(*reflect.SliceType), sv.Len(), newcap) + reflect.Copy(newv, sv) + sv.Set(newv) + } + if i >= av.Len() && sv != nil { + // Must be slice; gave up on array during i >= av.Cap(). + sv.SetLen(i + 1) + } + + // Decode into element. + if i < av.Len() { + d.value(av.Elem(i)) + } else { + // Ran out of fixed array: skip. + d.value(nil) + } + i++ + + // Next token must be , or ]. + op = d.scanWhile(scanSkipSpace) + if op == scanEndArray { + break + } + if op != scanArrayValue { + d.error(errPhase) + } + } + if i < av.Len() { + if sv == nil { + // Array. Zero the rest. + z := reflect.MakeZero(av.Type().(*reflect.ArrayType).Elem()) + for ; i < av.Len(); i++ { + av.Elem(i).SetValue(z) + } + } else { + sv.SetLen(i) + } + } +} + +// matchName returns true if key should be written to a field named name. +func matchName(key, name string) bool { + return strings.ToLower(key) == strings.ToLower(name) +} + +// object consumes an object from d.data[d.off-1:], decoding into the value v. +// the first byte of the object ('{') has been read already. +func (d *decodeState) object(v reflect.Value) { + // Check for unmarshaler. + unmarshaler, pv := d.indirect(v, false) + if unmarshaler != nil { + d.off-- + err := unmarshaler.UnmarshalJSON(d.next()) + if err != nil { + d.error(err) + } + return + } + v = pv + + // Decoding into nil interface? Switch to non-reflect code. + iv, ok := v.(*reflect.InterfaceValue) + if ok { + iv.Set(reflect.NewValue(d.objectInterface())) + return + } + + // Check type of target: struct or map[string]T + var ( + mv *reflect.MapValue + sv *reflect.StructValue + ) + switch v := v.(type) { + case *reflect.MapValue: + // map must have string type + t := v.Type().(*reflect.MapType) + if t.Key() != reflect.Typeof("") { + d.saveError(&UnmarshalTypeError{"object", v.Type()}) + break + } + mv = v + if mv.IsNil() { + mv.SetValue(reflect.MakeMap(t)) + } + case *reflect.StructValue: + sv = v + default: + d.saveError(&UnmarshalTypeError{"object", v.Type()}) + } + + if mv == nil && sv == nil { + d.off-- + d.next() // skip over { } in input + return + } + + for { + // Read opening " of string key or closing }. + op := d.scanWhile(scanSkipSpace) + if op == scanEndObject { + // closing } - can only happen on first iteration. + break + } + if op != scanBeginLiteral { + d.error(errPhase) + } + + // Read string key. + start := d.off - 1 + op = d.scanWhile(scanContinue) + item := d.data[start : d.off-1] + key, ok := unquote(item) + if !ok { + d.error(errPhase) + } + + // Figure out field corresponding to key. + var subv reflect.Value + if mv != nil { + subv = reflect.MakeZero(mv.Type().(*reflect.MapType).Elem()) + } else { + var f reflect.StructField + var ok bool + st := sv.Type().(*reflect.StructType) + // First try for field with that tag. + if isValidTag(key) { + for i := 0; i < sv.NumField(); i++ { + f = st.Field(i) + if f.Tag == key { + ok = true + break + } + } + } + if !ok { + // Second, exact match. + f, ok = st.FieldByName(key) + } + if !ok { + // Third, case-insensitive match. + f, ok = st.FieldByNameFunc(func(s string) bool { return matchName(key, s) }) + } + + // Extract value; name must be exported. 
+ if ok { + if f.PkgPath != "" { + d.saveError(&UnmarshalFieldError{key, st, f}) + } else { + subv = sv.FieldByIndex(f.Index) + } + } + } + + // Read : before value. + if op == scanSkipSpace { + op = d.scanWhile(scanSkipSpace) + } + if op != scanObjectKey { + d.error(errPhase) + } + + // Read value. + d.value(subv) + + // Write value back to map; + // if using struct, subv points into struct already. + if mv != nil { + mv.SetElem(reflect.NewValue(key), subv) + } + + // Next token must be , or }. + op = d.scanWhile(scanSkipSpace) + if op == scanEndObject { + break + } + if op != scanObjectValue { + d.error(errPhase) + } + } +} + +// literal consumes a literal from d.data[d.off-1:], decoding into the value v. +// The first byte of the literal has been read already +// (that's how the caller knows it's a literal). +func (d *decodeState) literal(v reflect.Value) { + // All bytes inside literal return scanContinue op code. + start := d.off - 1 + op := d.scanWhile(scanContinue) + + // Scan read one byte too far; back up. + d.off-- + d.scan.undo(op) + item := d.data[start:d.off] + + // Check for unmarshaler. + wantptr := item[0] == 'n' // null + unmarshaler, pv := d.indirect(v, wantptr) + if unmarshaler != nil { + err := unmarshaler.UnmarshalJSON(item) + if err != nil { + d.error(err) + } + return + } + v = pv + + switch c := item[0]; c { + case 'n': // null + switch v.(type) { + default: + d.saveError(&UnmarshalTypeError{"null", v.Type()}) + case *reflect.InterfaceValue, *reflect.PtrValue, *reflect.MapValue: + v.SetValue(nil) + } + + case 't', 'f': // true, false + value := c == 't' + switch v := v.(type) { + default: + d.saveError(&UnmarshalTypeError{"bool", v.Type()}) + case *reflect.BoolValue: + v.Set(value) + case *reflect.InterfaceValue: + v.Set(reflect.NewValue(value)) + } + + case '"': // string + s, ok := unquoteBytes(item) + if !ok { + d.error(errPhase) + } + switch v := v.(type) { + default: + d.saveError(&UnmarshalTypeError{"string", v.Type()}) + case *reflect.SliceValue: + if v.Type() != byteSliceType { + d.saveError(&UnmarshalTypeError{"string", v.Type()}) + break + } + b := make([]byte, base64.StdEncoding.DecodedLen(len(s))) + n, err := base64.StdEncoding.Decode(b, s) + if err != nil { + d.saveError(err) + break + } + v.Set(reflect.NewValue(b[0:n]).(*reflect.SliceValue)) + case *reflect.StringValue: + v.Set(string(s)) + case *reflect.InterfaceValue: + v.Set(reflect.NewValue(string(s))) + } + + default: // number + if c != '-' && (c < '0' || c > '9') { + d.error(errPhase) + } + s := string(item) + switch v := v.(type) { + default: + d.error(&UnmarshalTypeError{"number", v.Type()}) + case *reflect.InterfaceValue: + n, err := strconv.Atof64(s) + if err != nil { + d.saveError(&UnmarshalTypeError{"number " + s, v.Type()}) + break + } + v.Set(reflect.NewValue(n)) + + case *reflect.IntValue: + n, err := strconv.Atoi64(s) + if err != nil || v.Overflow(n) { + d.saveError(&UnmarshalTypeError{"number " + s, v.Type()}) + break + } + v.Set(n) + + case *reflect.UintValue: + n, err := strconv.Atoui64(s) + if err != nil || v.Overflow(n) { + d.saveError(&UnmarshalTypeError{"number " + s, v.Type()}) + break + } + v.Set(n) + + case *reflect.FloatValue: + n, err := strconv.AtofN(s, v.Type().Bits()) + if err != nil || v.Overflow(n) { + d.saveError(&UnmarshalTypeError{"number " + s, v.Type()}) + break + } + v.Set(n) + } + } +} + +// The xxxInterface routines build up a value to be stored +// in an empty interface. 
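The string case in literal above gives []byte destinations special treatment: the JSON string is base64-decoded into the slice. A tiny sketch, not part of the patch; the field name and input are illustrative, and "aGVsbG8=" is base64 for "hello".

var m struct{ Data []byte }
err := json.Unmarshal([]byte(`{"Data":"aGVsbG8="}`), &m)
// on success, string(m.Data) == "hello"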
They are not strictly necessary, +// but they avoid the weight of reflection in this common case. + +// valueInterface is like value but returns interface{} +func (d *decodeState) valueInterface() interface{} { + switch d.scanWhile(scanSkipSpace) { + default: + d.error(errPhase) + case scanBeginArray: + return d.arrayInterface() + case scanBeginObject: + return d.objectInterface() + case scanBeginLiteral: + return d.literalInterface() + } + panic("unreachable") +} + +// arrayInterface is like array but returns []interface{}. +func (d *decodeState) arrayInterface() []interface{} { + var v vector.Vector + for { + // Look ahead for ] - can only happen on first iteration. + op := d.scanWhile(scanSkipSpace) + if op == scanEndArray { + break + } + + // Back up so d.value can have the byte we just read. + d.off-- + d.scan.undo(op) + + v.Push(d.valueInterface()) + + // Next token must be , or ]. + op = d.scanWhile(scanSkipSpace) + if op == scanEndArray { + break + } + if op != scanArrayValue { + d.error(errPhase) + } + } + return v +} + +// objectInterface is like object but returns map[string]interface{}. +func (d *decodeState) objectInterface() map[string]interface{} { + m := make(map[string]interface{}) + for { + // Read opening " of string key or closing }. + op := d.scanWhile(scanSkipSpace) + if op == scanEndObject { + // closing } - can only happen on first iteration. + break + } + if op != scanBeginLiteral { + d.error(errPhase) + } + + // Read string key. + start := d.off - 1 + op = d.scanWhile(scanContinue) + item := d.data[start : d.off-1] + key, ok := unquote(item) + if !ok { + d.error(errPhase) + } + + // Read : before value. + if op == scanSkipSpace { + op = d.scanWhile(scanSkipSpace) + } + if op != scanObjectKey { + d.error(errPhase) + } + + // Read value. + m[key] = d.valueInterface() + + // Next token must be , or }. + op = d.scanWhile(scanSkipSpace) + if op == scanEndObject { + break + } + if op != scanObjectValue { + d.error(errPhase) + } + } + return m +} + + +// literalInterface is like literal but returns an interface value. +func (d *decodeState) literalInterface() interface{} { + // All bytes inside literal return scanContinue op code. + start := d.off - 1 + op := d.scanWhile(scanContinue) + + // Scan read one byte too far; back up. + d.off-- + d.scan.undo(op) + item := d.data[start:d.off] + + switch c := item[0]; c { + case 'n': // null + return nil + + case 't', 'f': // true, false + return c == 't' + + case '"': // string + s, ok := unquote(item) + if !ok { + d.error(errPhase) + } + return s + + default: // number + if c != '-' && (c < '0' || c > '9') { + d.error(errPhase) + } + n, err := strconv.Atof64(string(item)) + if err != nil { + d.saveError(&UnmarshalTypeError{"number " + string(item), reflect.Typeof(0.0)}) + } + return n + } + panic("unreachable") +} + +// getu4 decodes \uXXXX from the beginning of s, returning the hex value, +// or it returns -1. +func getu4(s []byte) int { + if len(s) < 6 || s[0] != '\\' || s[1] != 'u' { + return -1 + } + rune, err := strconv.Btoui64(string(s[2:6]), 16) + if err != nil { + return -1 + } + return int(rune) +} + +// unquote converts a quoted JSON string literal s into an actual string t. +// The rules are different than for Go, so cannot use strconv.Unquote. 
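getu4 above and unquoteBytes below handle \uXXXX escapes, combining UTF-16 surrogate pairs into a single rune and falling back to U+FFFD when a pair is invalid. For example (not part of the patch; the escape sequence is illustrative):

// \uD834\uDD1E is the UTF-16 surrogate pair for U+1D11E (musical G clef).
var s string
json.Unmarshal([]byte(`"\uD834\uDD1E"`), &s)
// s is the single rune U+1D11E encoded as 4 bytes of UTF-8, so len(s) == 4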
+func unquote(s []byte) (t string, ok bool) { + s, ok = unquoteBytes(s) + t = string(s) + return +} + +func unquoteBytes(s []byte) (t []byte, ok bool) { + if len(s) < 2 || s[0] != '"' || s[len(s)-1] != '"' { + return + } + s = s[1 : len(s)-1] + + // Check for unusual characters. If there are none, + // then no unquoting is needed, so return a slice of the + // original bytes. + r := 0 + for r < len(s) { + c := s[r] + if c == '\\' || c == '"' || c < ' ' { + break + } + if c < utf8.RuneSelf { + r++ + continue + } + rune, size := utf8.DecodeRune(s[r:]) + if rune == utf8.RuneError && size == 1 { + break + } + r += size + } + if r == len(s) { + return s, true + } + + b := make([]byte, len(s)+2*utf8.UTFMax) + w := copy(b, s[0:r]) + for r < len(s) { + // Out of room? Can only happen if s is full of + // malformed UTF-8 and we're replacing each + // byte with RuneError. + if w >= len(b)-2*utf8.UTFMax { + nb := make([]byte, (len(b)+utf8.UTFMax)*2) + copy(nb, b[0:w]) + b = nb + } + switch c := s[r]; { + case c == '\\': + r++ + if r >= len(s) { + return + } + switch s[r] { + default: + return + case '"', '\\', '/', '\'': + b[w] = s[r] + r++ + w++ + case 'b': + b[w] = '\b' + r++ + w++ + case 'f': + b[w] = '\f' + r++ + w++ + case 'n': + b[w] = '\n' + r++ + w++ + case 'r': + b[w] = '\r' + r++ + w++ + case 't': + b[w] = '\t' + r++ + w++ + case 'u': + r-- + rune := getu4(s[r:]) + if rune < 0 { + return + } + r += 6 + if utf16.IsSurrogate(rune) { + rune1 := getu4(s[r:]) + if dec := utf16.DecodeRune(rune, rune1); dec != unicode.ReplacementChar { + // A valid pair; consume. + r += 6 + w += utf8.EncodeRune(b[w:], dec) + break + } + // Invalid surrogate; fall back to replacement rune. + rune = unicode.ReplacementChar + } + w += utf8.EncodeRune(b[w:], rune) + } + + // Quote, control characters are invalid. + case c == '"', c < ' ': + return + + // ASCII + case c < utf8.RuneSelf: + b[w] = c + r++ + w++ + + // Coerce to well-formed UTF-8. + default: + rune, size := utf8.DecodeRune(s[r:]) + r += size + w += utf8.EncodeRune(b[w:], rune) + } + } + return b[0:w], true +} diff --git a/src/cmd/gofix/testdata/reflect.decode.go.out b/src/cmd/gofix/testdata/reflect.decode.go.out new file mode 100644 index 000000000..a5fd33912 --- /dev/null +++ b/src/cmd/gofix/testdata/reflect.decode.go.out @@ -0,0 +1,910 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Represents JSON data structure using native Go types: booleans, floats, +// strings, arrays, and maps. + +package json + +import ( + "container/vector" + "encoding/base64" + "os" + "reflect" + "runtime" + "strconv" + "strings" + "unicode" + "utf16" + "utf8" +) + +// Unmarshal parses the JSON-encoded data and stores the result +// in the value pointed to by v. +// +// Unmarshal traverses the value v recursively. +// If an encountered value implements the Unmarshaler interface, +// Unmarshal calls its UnmarshalJSON method with a well-formed +// JSON encoding. 
+// +// Otherwise, Unmarshal uses the inverse of the encodings that +// Marshal uses, allocating maps, slices, and pointers as necessary, +// with the following additional rules: +// +// To unmarshal a JSON value into a nil interface value, the +// type stored in the interface value is one of: +// +// bool, for JSON booleans +// float64, for JSON numbers +// string, for JSON strings +// []interface{}, for JSON arrays +// map[string]interface{}, for JSON objects +// nil for JSON null +// +// If a JSON value is not appropriate for a given target type, +// or if a JSON number overflows the target type, Unmarshal +// skips that field and completes the unmarshalling as best it can. +// If no more serious errors are encountered, Unmarshal returns +// an UnmarshalTypeError describing the earliest such error. +// +func Unmarshal(data []byte, v interface{}) os.Error { + d := new(decodeState).init(data) + + // Quick check for well-formedness. + // Avoids filling out half a data structure + // before discovering a JSON syntax error. + err := checkValid(data, &d.scan) + if err != nil { + return err + } + + return d.unmarshal(v) +} + +// Unmarshaler is the interface implemented by objects +// that can unmarshal a JSON description of themselves. +// The input can be assumed to be a valid JSON object +// encoding. UnmarshalJSON must copy the JSON data +// if it wishes to retain the data after returning. +type Unmarshaler interface { + UnmarshalJSON([]byte) os.Error +} + + +// An UnmarshalTypeError describes a JSON value that was +// not appropriate for a value of a specific Go type. +type UnmarshalTypeError struct { + Value string // description of JSON value - "bool", "array", "number -5" + Type reflect.Type // type of Go value it could not be assigned to +} + +func (e *UnmarshalTypeError) String() string { + return "json: cannot unmarshal " + e.Value + " into Go value of type " + e.Type.String() +} + +// An UnmarshalFieldError describes a JSON object key that +// led to an unexported (and therefore unwritable) struct field. +type UnmarshalFieldError struct { + Key string + Type reflect.Type + Field reflect.StructField +} + +func (e *UnmarshalFieldError) String() string { + return "json: cannot unmarshal object key " + strconv.Quote(e.Key) + " into unexported field " + e.Field.Name + " of type " + e.Type.String() +} + +// An InvalidUnmarshalError describes an invalid argument passed to Unmarshal. +// (The argument to Unmarshal must be a non-nil pointer.) +type InvalidUnmarshalError struct { + Type reflect.Type +} + +func (e *InvalidUnmarshalError) String() string { + if e.Type == nil { + return "json: Unmarshal(nil)" + } + + if e.Type.Kind() != reflect.Ptr { + return "json: Unmarshal(non-pointer " + e.Type.String() + ")" + } + return "json: Unmarshal(nil " + e.Type.String() + ")" +} + +func (d *decodeState) unmarshal(v interface{}) (err os.Error) { + defer func() { + if r := recover(); r != nil { + if _, ok := r.(runtime.Error); ok { + panic(r) + } + err = r.(os.Error) + } + }() + + rv := reflect.NewValue(v) + pv := rv + if pv.Kind() != reflect.Ptr || + pv.IsNil() { + return &InvalidUnmarshalError{reflect.Typeof(v)} + } + + d.scan.reset() + // We decode rv not pv.Elem because the Unmarshaler interface + // test must be applied at the top level of the value. + d.value(rv) + return d.savedError +} + +// decodeState represents the state while decoding a JSON value. 
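The reflect.decode .in and .out files shown here are gofix test data for this release's reflect API change: typed wrappers such as *reflect.PtrValue, *reflect.InterfaceValue, and *reflect.MapValue give way to Kind checks on a plain reflect.Value, nil values become reflect.Value{} and !v.IsValid(), MakeZero becomes Zero, PointTo(x) becomes Set(x.Addr()), Elem(i) on arrays becomes Index(i), and SetElem becomes SetMapIndex. A condensed sketch of one such rewrite, not part of the patch; derefOnce is an illustrative name and rv is assumed settable, as it is inside indirect.

func derefOnce(rv reflect.Value) reflect.Value {
	// Old API, as in reflect.decode.go.in:
	//	pv, ok := rv.(*reflect.PtrValue)
	//	if !ok { return rv }
	//	if pv.IsNil() {
	//		pv.PointTo(reflect.MakeZero(pv.Type().(*reflect.PtrType).Elem()))
	//	}
	//	return pv.Elem()
	//
	// New API, as in reflect.decode.go.out:
	if rv.Kind() != reflect.Ptr {
		return rv
	}
	if rv.IsNil() {
		rv.Set(reflect.Zero(rv.Type().Elem()).Addr())
	}
	return rv.Elem()
}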
+type decodeState struct { + data []byte + off int // read offset in data + scan scanner + nextscan scanner // for calls to nextValue + savedError os.Error +} + +// errPhase is used for errors that should not happen unless +// there is a bug in the JSON decoder or something is editing +// the data slice while the decoder executes. +var errPhase = os.NewError("JSON decoder out of sync - data changing underfoot?") + +func (d *decodeState) init(data []byte) *decodeState { + d.data = data + d.off = 0 + d.savedError = nil + return d +} + +// error aborts the decoding by panicking with err. +func (d *decodeState) error(err os.Error) { + panic(err) +} + +// saveError saves the first err it is called with, +// for reporting at the end of the unmarshal. +func (d *decodeState) saveError(err os.Error) { + if d.savedError == nil { + d.savedError = err + } +} + +// next cuts off and returns the next full JSON value in d.data[d.off:]. +// The next value is known to be an object or array, not a literal. +func (d *decodeState) next() []byte { + c := d.data[d.off] + item, rest, err := nextValue(d.data[d.off:], &d.nextscan) + if err != nil { + d.error(err) + } + d.off = len(d.data) - len(rest) + + // Our scanner has seen the opening brace/bracket + // and thinks we're still in the middle of the object. + // invent a closing brace/bracket to get it out. + if c == '{' { + d.scan.step(&d.scan, '}') + } else { + d.scan.step(&d.scan, ']') + } + + return item +} + +// scanWhile processes bytes in d.data[d.off:] until it +// receives a scan code not equal to op. +// It updates d.off and returns the new scan code. +func (d *decodeState) scanWhile(op int) int { + var newOp int + for { + if d.off >= len(d.data) { + newOp = d.scan.eof() + d.off = len(d.data) + 1 // mark processed EOF with len+1 + } else { + c := int(d.data[d.off]) + d.off++ + newOp = d.scan.step(&d.scan, c) + } + if newOp != op { + break + } + } + return newOp +} + +// value decodes a JSON value from d.data[d.off:] into the value. +// it updates d.off to point past the decoded value. +func (d *decodeState) value(v reflect.Value) { + if !v.IsValid() { + _, rest, err := nextValue(d.data[d.off:], &d.nextscan) + if err != nil { + d.error(err) + } + d.off = len(d.data) - len(rest) + + // d.scan thinks we're still at the beginning of the item. + // Feed in an empty string - the shortest, simplest value - + // so that it knows we got to the end of the value. + if d.scan.step == stateRedo { + panic("redo") + } + d.scan.step(&d.scan, '"') + d.scan.step(&d.scan, '"') + return + } + + switch op := d.scanWhile(scanSkipSpace); op { + default: + d.error(errPhase) + + case scanBeginArray: + d.array(v) + + case scanBeginObject: + d.object(v) + + case scanBeginLiteral: + d.literal(v) + } +} + +// indirect walks down v allocating pointers as needed, +// until it gets to a non-pointer. +// if it encounters an Unmarshaler, indirect stops and returns that. +// if wantptr is true, indirect stops at the last pointer. +func (d *decodeState) indirect(v reflect.Value, wantptr bool) (Unmarshaler, reflect.Value) { + for { + var isUnmarshaler bool + if v.Type().NumMethod() > 0 { + // Remember that this is an unmarshaler, + // but wait to return it until after allocating + // the pointer (if necessary). 
+ _, isUnmarshaler = v.Interface().(Unmarshaler) + } + + if iv := v; iv.Kind() == reflect.Interface && !iv.IsNil() { + v = iv.Elem() + continue + } + pv := v + if pv.Kind() != reflect.Ptr { + break + } + + if pv.Elem().Kind() != reflect.Ptr && + wantptr && !isUnmarshaler { + return nil, pv + } + if pv.IsNil() { + pv.Set(reflect.Zero(pv.Type().Elem()).Addr()) + } + if isUnmarshaler { + // Using v.Interface().(Unmarshaler) + // here means that we have to use a pointer + // as the struct field. We cannot use a value inside + // a pointer to a struct, because in that case + // v.Interface() is the value (x.f) not the pointer (&x.f). + // This is an unfortunate consequence of reflect. + // An alternative would be to look up the + // UnmarshalJSON method and return a FuncValue. + return v.Interface().(Unmarshaler), reflect.Value{} + } + v = pv.Elem() + } + return nil, v +} + +// array consumes an array from d.data[d.off-1:], decoding into the value v. +// the first byte of the array ('[') has been read already. +func (d *decodeState) array(v reflect.Value) { + // Check for unmarshaler. + unmarshaler, pv := d.indirect(v, false) + if unmarshaler != nil { + d.off-- + err := unmarshaler.UnmarshalJSON(d.next()) + if err != nil { + d.error(err) + } + return + } + v = pv + + // Decoding into nil interface? Switch to non-reflect code. + iv := v + ok := iv.Kind() == reflect.Interface + if ok { + iv.Set(reflect.NewValue(d.arrayInterface())) + return + } + + // Check type of target. + av := v + if av.Kind() != reflect.Array && av.Kind() != reflect.Slice { + d.saveError(&UnmarshalTypeError{"array", v.Type()}) + d.off-- + d.next() + return + } + + sv := v + + i := 0 + for { + // Look ahead for ] - can only happen on first iteration. + op := d.scanWhile(scanSkipSpace) + if op == scanEndArray { + break + } + + // Back up so d.value can have the byte we just read. + d.off-- + d.scan.undo(op) + + // Get element of array, growing if necessary. + if i >= av.Cap() && sv.IsValid() { + newcap := sv.Cap() + sv.Cap()/2 + if newcap < 4 { + newcap = 4 + } + newv := reflect.MakeSlice(sv.Type(), sv.Len(), newcap) + reflect.Copy(newv, sv) + sv.Set(newv) + } + if i >= av.Len() && sv.IsValid() { + // Must be slice; gave up on array during i >= av.Cap(). + sv.SetLen(i + 1) + } + + // Decode into element. + if i < av.Len() { + d.value(av.Index(i)) + } else { + // Ran out of fixed array: skip. + d.value(reflect.Value{}) + } + i++ + + // Next token must be , or ]. + op = d.scanWhile(scanSkipSpace) + if op == scanEndArray { + break + } + if op != scanArrayValue { + d.error(errPhase) + } + } + if i < av.Len() { + if !sv.IsValid() { + // Array. Zero the rest. + z := reflect.Zero(av.Type().Elem()) + for ; i < av.Len(); i++ { + av.Index(i).Set(z) + } + } else { + sv.SetLen(i) + } + } +} + +// matchName returns true if key should be written to a field named name. +func matchName(key, name string) bool { + return strings.ToLower(key) == strings.ToLower(name) +} + +// object consumes an object from d.data[d.off-1:], decoding into the value v. +// the first byte of the object ('{') has been read already. +func (d *decodeState) object(v reflect.Value) { + // Check for unmarshaler. + unmarshaler, pv := d.indirect(v, false) + if unmarshaler != nil { + d.off-- + err := unmarshaler.UnmarshalJSON(d.next()) + if err != nil { + d.error(err) + } + return + } + v = pv + + // Decoding into nil interface? Switch to non-reflect code. 
+ iv := v + if iv.Kind() == reflect.Interface { + iv.Set(reflect.NewValue(d.objectInterface())) + return + } + + // Check type of target: struct or map[string]T + var ( + mv reflect.Value + sv reflect.Value + ) + switch v.Kind() { + case reflect.Map: + // map must have string type + t := v.Type() + if t.Key() != reflect.Typeof("") { + d.saveError(&UnmarshalTypeError{"object", v.Type()}) + break + } + mv = v + if mv.IsNil() { + mv.Set(reflect.MakeMap(t)) + } + case reflect.Struct: + sv = v + default: + d.saveError(&UnmarshalTypeError{"object", v.Type()}) + } + + if !mv.IsValid() && !sv.IsValid() { + d.off-- + d.next() // skip over { } in input + return + } + + for { + // Read opening " of string key or closing }. + op := d.scanWhile(scanSkipSpace) + if op == scanEndObject { + // closing } - can only happen on first iteration. + break + } + if op != scanBeginLiteral { + d.error(errPhase) + } + + // Read string key. + start := d.off - 1 + op = d.scanWhile(scanContinue) + item := d.data[start : d.off-1] + key, ok := unquote(item) + if !ok { + d.error(errPhase) + } + + // Figure out field corresponding to key. + var subv reflect.Value + if mv.IsValid() { + subv = reflect.Zero(mv.Type().Elem()) + } else { + var f reflect.StructField + var ok bool + st := sv.Type() + // First try for field with that tag. + if isValidTag(key) { + for i := 0; i < sv.NumField(); i++ { + f = st.Field(i) + if f.Tag == key { + ok = true + break + } + } + } + if !ok { + // Second, exact match. + f, ok = st.FieldByName(key) + } + if !ok { + // Third, case-insensitive match. + f, ok = st.FieldByNameFunc(func(s string) bool { return matchName(key, s) }) + } + + // Extract value; name must be exported. + if ok { + if f.PkgPath != "" { + d.saveError(&UnmarshalFieldError{key, st, f}) + } else { + subv = sv.FieldByIndex(f.Index) + } + } + } + + // Read : before value. + if op == scanSkipSpace { + op = d.scanWhile(scanSkipSpace) + } + if op != scanObjectKey { + d.error(errPhase) + } + + // Read value. + d.value(subv) + + // Write value back to map; + // if using struct, subv points into struct already. + if mv.IsValid() { + mv.SetMapIndex(reflect.NewValue(key), subv) + } + + // Next token must be , or }. + op = d.scanWhile(scanSkipSpace) + if op == scanEndObject { + break + } + if op != scanObjectValue { + d.error(errPhase) + } + } +} + +// literal consumes a literal from d.data[d.off-1:], decoding into the value v. +// The first byte of the literal has been read already +// (that's how the caller knows it's a literal). +func (d *decodeState) literal(v reflect.Value) { + // All bytes inside literal return scanContinue op code. + start := d.off - 1 + op := d.scanWhile(scanContinue) + + // Scan read one byte too far; back up. + d.off-- + d.scan.undo(op) + item := d.data[start:d.off] + + // Check for unmarshaler. 
+ wantptr := item[0] == 'n' // null + unmarshaler, pv := d.indirect(v, wantptr) + if unmarshaler != nil { + err := unmarshaler.UnmarshalJSON(item) + if err != nil { + d.error(err) + } + return + } + v = pv + + switch c := item[0]; c { + case 'n': // null + switch v.Kind() { + default: + d.saveError(&UnmarshalTypeError{"null", v.Type()}) + case reflect.Interface, reflect.Ptr, reflect.Map: + v.Set(reflect.Zero(v.Type())) + } + + case 't', 'f': // true, false + value := c == 't' + switch v.Kind() { + default: + d.saveError(&UnmarshalTypeError{"bool", v.Type()}) + case reflect.Bool: + v.SetBool(value) + case reflect.Interface: + v.Set(reflect.NewValue(value)) + } + + case '"': // string + s, ok := unquoteBytes(item) + if !ok { + d.error(errPhase) + } + switch v.Kind() { + default: + d.saveError(&UnmarshalTypeError{"string", v.Type()}) + case reflect.Slice: + if v.Type() != byteSliceType { + d.saveError(&UnmarshalTypeError{"string", v.Type()}) + break + } + b := make([]byte, base64.StdEncoding.DecodedLen(len(s))) + n, err := base64.StdEncoding.Decode(b, s) + if err != nil { + d.saveError(err) + break + } + v.Set(reflect.NewValue(b[0:n])) + case reflect.String: + v.SetString(string(s)) + case reflect.Interface: + v.Set(reflect.NewValue(string(s))) + } + + default: // number + if c != '-' && (c < '0' || c > '9') { + d.error(errPhase) + } + s := string(item) + switch v.Kind() { + default: + d.error(&UnmarshalTypeError{"number", v.Type()}) + case reflect.Interface: + n, err := strconv.Atof64(s) + if err != nil { + d.saveError(&UnmarshalTypeError{"number " + s, v.Type()}) + break + } + v.Set(reflect.NewValue(n)) + + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + n, err := strconv.Atoi64(s) + if err != nil || v.OverflowInt(n) { + d.saveError(&UnmarshalTypeError{"number " + s, v.Type()}) + break + } + v.SetInt(n) + + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + n, err := strconv.Atoui64(s) + if err != nil || v.OverflowUint(n) { + d.saveError(&UnmarshalTypeError{"number " + s, v.Type()}) + break + } + v.SetUint(n) + + case reflect.Float32, reflect.Float64: + n, err := strconv.AtofN(s, v.Type().Bits()) + if err != nil || v.OverflowFloat(n) { + d.saveError(&UnmarshalTypeError{"number " + s, v.Type()}) + break + } + v.SetFloat(n) + } + } +} + +// The xxxInterface routines build up a value to be stored +// in an empty interface. They are not strictly necessary, +// but they avoid the weight of reflection in this common case. + +// valueInterface is like value but returns interface{} +func (d *decodeState) valueInterface() interface{} { + switch d.scanWhile(scanSkipSpace) { + default: + d.error(errPhase) + case scanBeginArray: + return d.arrayInterface() + case scanBeginObject: + return d.objectInterface() + case scanBeginLiteral: + return d.literalInterface() + } + panic("unreachable") +} + +// arrayInterface is like array but returns []interface{}. +func (d *decodeState) arrayInterface() []interface{} { + var v vector.Vector + for { + // Look ahead for ] - can only happen on first iteration. + op := d.scanWhile(scanSkipSpace) + if op == scanEndArray { + break + } + + // Back up so d.value can have the byte we just read. + d.off-- + d.scan.undo(op) + + v.Push(d.valueInterface()) + + // Next token must be , or ]. 
+ op = d.scanWhile(scanSkipSpace) + if op == scanEndArray { + break + } + if op != scanArrayValue { + d.error(errPhase) + } + } + return v +} + +// objectInterface is like object but returns map[string]interface{}. +func (d *decodeState) objectInterface() map[string]interface{} { + m := make(map[string]interface{}) + for { + // Read opening " of string key or closing }. + op := d.scanWhile(scanSkipSpace) + if op == scanEndObject { + // closing } - can only happen on first iteration. + break + } + if op != scanBeginLiteral { + d.error(errPhase) + } + + // Read string key. + start := d.off - 1 + op = d.scanWhile(scanContinue) + item := d.data[start : d.off-1] + key, ok := unquote(item) + if !ok { + d.error(errPhase) + } + + // Read : before value. + if op == scanSkipSpace { + op = d.scanWhile(scanSkipSpace) + } + if op != scanObjectKey { + d.error(errPhase) + } + + // Read value. + m[key] = d.valueInterface() + + // Next token must be , or }. + op = d.scanWhile(scanSkipSpace) + if op == scanEndObject { + break + } + if op != scanObjectValue { + d.error(errPhase) + } + } + return m +} + + +// literalInterface is like literal but returns an interface value. +func (d *decodeState) literalInterface() interface{} { + // All bytes inside literal return scanContinue op code. + start := d.off - 1 + op := d.scanWhile(scanContinue) + + // Scan read one byte too far; back up. + d.off-- + d.scan.undo(op) + item := d.data[start:d.off] + + switch c := item[0]; c { + case 'n': // null + return nil + + case 't', 'f': // true, false + return c == 't' + + case '"': // string + s, ok := unquote(item) + if !ok { + d.error(errPhase) + } + return s + + default: // number + if c != '-' && (c < '0' || c > '9') { + d.error(errPhase) + } + n, err := strconv.Atof64(string(item)) + if err != nil { + d.saveError(&UnmarshalTypeError{"number " + string(item), reflect.Typeof(0.0)}) + } + return n + } + panic("unreachable") +} + +// getu4 decodes \uXXXX from the beginning of s, returning the hex value, +// or it returns -1. +func getu4(s []byte) int { + if len(s) < 6 || s[0] != '\\' || s[1] != 'u' { + return -1 + } + rune, err := strconv.Btoui64(string(s[2:6]), 16) + if err != nil { + return -1 + } + return int(rune) +} + +// unquote converts a quoted JSON string literal s into an actual string t. +// The rules are different than for Go, so cannot use strconv.Unquote. +func unquote(s []byte) (t string, ok bool) { + s, ok = unquoteBytes(s) + t = string(s) + return +} + +func unquoteBytes(s []byte) (t []byte, ok bool) { + if len(s) < 2 || s[0] != '"' || s[len(s)-1] != '"' { + return + } + s = s[1 : len(s)-1] + + // Check for unusual characters. If there are none, + // then no unquoting is needed, so return a slice of the + // original bytes. + r := 0 + for r < len(s) { + c := s[r] + if c == '\\' || c == '"' || c < ' ' { + break + } + if c < utf8.RuneSelf { + r++ + continue + } + rune, size := utf8.DecodeRune(s[r:]) + if rune == utf8.RuneError && size == 1 { + break + } + r += size + } + if r == len(s) { + return s, true + } + + b := make([]byte, len(s)+2*utf8.UTFMax) + w := copy(b, s[0:r]) + for r < len(s) { + // Out of room? Can only happen if s is full of + // malformed UTF-8 and we're replacing each + // byte with RuneError. 
+ if w >= len(b)-2*utf8.UTFMax { + nb := make([]byte, (len(b)+utf8.UTFMax)*2) + copy(nb, b[0:w]) + b = nb + } + switch c := s[r]; { + case c == '\\': + r++ + if r >= len(s) { + return + } + switch s[r] { + default: + return + case '"', '\\', '/', '\'': + b[w] = s[r] + r++ + w++ + case 'b': + b[w] = '\b' + r++ + w++ + case 'f': + b[w] = '\f' + r++ + w++ + case 'n': + b[w] = '\n' + r++ + w++ + case 'r': + b[w] = '\r' + r++ + w++ + case 't': + b[w] = '\t' + r++ + w++ + case 'u': + r-- + rune := getu4(s[r:]) + if rune < 0 { + return + } + r += 6 + if utf16.IsSurrogate(rune) { + rune1 := getu4(s[r:]) + if dec := utf16.DecodeRune(rune, rune1); dec != unicode.ReplacementChar { + // A valid pair; consume. + r += 6 + w += utf8.EncodeRune(b[w:], dec) + break + } + // Invalid surrogate; fall back to replacement rune. + rune = unicode.ReplacementChar + } + w += utf8.EncodeRune(b[w:], rune) + } + + // Quote, control characters are invalid. + case c == '"', c < ' ': + return + + // ASCII + case c < utf8.RuneSelf: + b[w] = c + r++ + w++ + + // Coerce to well-formed UTF-8. + default: + rune, size := utf8.DecodeRune(s[r:]) + r += size + w += utf8.EncodeRune(b[w:], rune) + } + } + return b[0:w], true +} diff --git a/src/cmd/gofix/testdata/reflect.decoder.go.in b/src/cmd/gofix/testdata/reflect.decoder.go.in new file mode 100644 index 000000000..34364161a --- /dev/null +++ b/src/cmd/gofix/testdata/reflect.decoder.go.in @@ -0,0 +1,196 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package gob + +import ( + "bufio" + "bytes" + "io" + "os" + "reflect" + "sync" +) + +// A Decoder manages the receipt of type and data information read from the +// remote side of a connection. +type Decoder struct { + mutex sync.Mutex // each item must be received atomically + r io.Reader // source of the data + buf bytes.Buffer // buffer for more efficient i/o from r + wireType map[typeId]*wireType // map from remote ID to local description + decoderCache map[reflect.Type]map[typeId]**decEngine // cache of compiled engines + ignorerCache map[typeId]**decEngine // ditto for ignored objects + freeList *decoderState // list of free decoderStates; avoids reallocation + countBuf []byte // used for decoding integers while parsing messages + tmp []byte // temporary storage for i/o; saves reallocating + err os.Error +} + +// NewDecoder returns a new decoder that reads from the io.Reader. +func NewDecoder(r io.Reader) *Decoder { + dec := new(Decoder) + dec.r = bufio.NewReader(r) + dec.wireType = make(map[typeId]*wireType) + dec.decoderCache = make(map[reflect.Type]map[typeId]**decEngine) + dec.ignorerCache = make(map[typeId]**decEngine) + dec.countBuf = make([]byte, 9) // counts may be uint64s (unlikely!), require 9 bytes + + return dec +} + +// recvType loads the definition of a type. +func (dec *Decoder) recvType(id typeId) { + // Have we already seen this type? That's an error + if id < firstUserId || dec.wireType[id] != nil { + dec.err = os.ErrorString("gob: duplicate type received") + return + } + + // Type: + wire := new(wireType) + dec.decodeValue(tWireType, reflect.NewValue(wire)) + if dec.err != nil { + return + } + // Remember we've seen this type. + dec.wireType[id] = wire +} + +// recvMessage reads the next count-delimited item from the input. It is the converse +// of Encoder.writeMessage. It returns false on EOF or other error reading the message. +func (dec *Decoder) recvMessage() bool { + // Read a count. 
+ nbytes, _, err := decodeUintReader(dec.r, dec.countBuf) + if err != nil { + dec.err = err + return false + } + dec.readMessage(int(nbytes)) + return dec.err == nil +} + +// readMessage reads the next nbytes bytes from the input. +func (dec *Decoder) readMessage(nbytes int) { + // Allocate the buffer. + if cap(dec.tmp) < nbytes { + dec.tmp = make([]byte, nbytes+100) // room to grow + } + dec.tmp = dec.tmp[:nbytes] + + // Read the data + _, dec.err = io.ReadFull(dec.r, dec.tmp) + if dec.err != nil { + if dec.err == os.EOF { + dec.err = io.ErrUnexpectedEOF + } + return + } + dec.buf.Write(dec.tmp) +} + +// toInt turns an encoded uint64 into an int, according to the marshaling rules. +func toInt(x uint64) int64 { + i := int64(x >> 1) + if x&1 != 0 { + i = ^i + } + return i +} + +func (dec *Decoder) nextInt() int64 { + n, _, err := decodeUintReader(&dec.buf, dec.countBuf) + if err != nil { + dec.err = err + } + return toInt(n) +} + +func (dec *Decoder) nextUint() uint64 { + n, _, err := decodeUintReader(&dec.buf, dec.countBuf) + if err != nil { + dec.err = err + } + return n +} + +// decodeTypeSequence parses: +// TypeSequence +// (TypeDefinition DelimitedTypeDefinition*)? +// and returns the type id of the next value. It returns -1 at +// EOF. Upon return, the remainder of dec.buf is the value to be +// decoded. If this is an interface value, it can be ignored by +// simply resetting that buffer. +func (dec *Decoder) decodeTypeSequence(isInterface bool) typeId { + for dec.err == nil { + if dec.buf.Len() == 0 { + if !dec.recvMessage() { + break + } + } + // Receive a type id. + id := typeId(dec.nextInt()) + if id >= 0 { + // Value follows. + return id + } + // Type definition for (-id) follows. + dec.recvType(-id) + // When decoding an interface, after a type there may be a + // DelimitedValue still in the buffer. Skip its count. + // (Alternatively, the buffer is empty and the byte count + // will be absorbed by recvMessage.) + if dec.buf.Len() > 0 { + if !isInterface { + dec.err = os.ErrorString("extra data in buffer") + break + } + dec.nextUint() + } + } + return -1 +} + +// Decode reads the next value from the connection and stores +// it in the data represented by the empty interface value. +// If e is nil, the value will be discarded. Otherwise, +// the value underlying e must either be the correct type for the next +// data item received, and must be a pointer. +func (dec *Decoder) Decode(e interface{}) os.Error { + if e == nil { + return dec.DecodeValue(nil) + } + value := reflect.NewValue(e) + // If e represents a value as opposed to a pointer, the answer won't + // get back to the caller. Make sure it's a pointer. + if value.Type().Kind() != reflect.Ptr { + dec.err = os.ErrorString("gob: attempt to decode into a non-pointer") + return dec.err + } + return dec.DecodeValue(value) +} + +// DecodeValue reads the next value from the connection and stores +// it in the data represented by the reflection value. +// The value must be the correct type for the next +// data item received, or it may be nil, which means the +// value will be discarded. +func (dec *Decoder) DecodeValue(value reflect.Value) os.Error { + // Make sure we're single-threaded through here. + dec.mutex.Lock() + defer dec.mutex.Unlock() + + dec.buf.Reset() // In case data lingers from previous invocation. 
+ dec.err = nil + id := dec.decodeTypeSequence(false) + if dec.err == nil { + dec.decodeValue(id, value) + } + return dec.err +} + +// If debug.go is compiled into the program , debugFunc prints a human-readable +// representation of the gob data read from r by calling that file's Debug function. +// Otherwise it is nil. +var debugFunc func(io.Reader) diff --git a/src/cmd/gofix/testdata/reflect.decoder.go.out b/src/cmd/gofix/testdata/reflect.decoder.go.out new file mode 100644 index 000000000..a631c27a2 --- /dev/null +++ b/src/cmd/gofix/testdata/reflect.decoder.go.out @@ -0,0 +1,196 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package gob + +import ( + "bufio" + "bytes" + "io" + "os" + "reflect" + "sync" +) + +// A Decoder manages the receipt of type and data information read from the +// remote side of a connection. +type Decoder struct { + mutex sync.Mutex // each item must be received atomically + r io.Reader // source of the data + buf bytes.Buffer // buffer for more efficient i/o from r + wireType map[typeId]*wireType // map from remote ID to local description + decoderCache map[reflect.Type]map[typeId]**decEngine // cache of compiled engines + ignorerCache map[typeId]**decEngine // ditto for ignored objects + freeList *decoderState // list of free decoderStates; avoids reallocation + countBuf []byte // used for decoding integers while parsing messages + tmp []byte // temporary storage for i/o; saves reallocating + err os.Error +} + +// NewDecoder returns a new decoder that reads from the io.Reader. +func NewDecoder(r io.Reader) *Decoder { + dec := new(Decoder) + dec.r = bufio.NewReader(r) + dec.wireType = make(map[typeId]*wireType) + dec.decoderCache = make(map[reflect.Type]map[typeId]**decEngine) + dec.ignorerCache = make(map[typeId]**decEngine) + dec.countBuf = make([]byte, 9) // counts may be uint64s (unlikely!), require 9 bytes + + return dec +} + +// recvType loads the definition of a type. +func (dec *Decoder) recvType(id typeId) { + // Have we already seen this type? That's an error + if id < firstUserId || dec.wireType[id] != nil { + dec.err = os.ErrorString("gob: duplicate type received") + return + } + + // Type: + wire := new(wireType) + dec.decodeValue(tWireType, reflect.NewValue(wire)) + if dec.err != nil { + return + } + // Remember we've seen this type. + dec.wireType[id] = wire +} + +// recvMessage reads the next count-delimited item from the input. It is the converse +// of Encoder.writeMessage. It returns false on EOF or other error reading the message. +func (dec *Decoder) recvMessage() bool { + // Read a count. + nbytes, _, err := decodeUintReader(dec.r, dec.countBuf) + if err != nil { + dec.err = err + return false + } + dec.readMessage(int(nbytes)) + return dec.err == nil +} + +// readMessage reads the next nbytes bytes from the input. +func (dec *Decoder) readMessage(nbytes int) { + // Allocate the buffer. + if cap(dec.tmp) < nbytes { + dec.tmp = make([]byte, nbytes+100) // room to grow + } + dec.tmp = dec.tmp[:nbytes] + + // Read the data + _, dec.err = io.ReadFull(dec.r, dec.tmp) + if dec.err != nil { + if dec.err == os.EOF { + dec.err = io.ErrUnexpectedEOF + } + return + } + dec.buf.Write(dec.tmp) +} + +// toInt turns an encoded uint64 into an int, according to the marshaling rules. 
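toInt directly below undoes gob's signed-integer encoding, in which the low bit of the unsigned wire value carries the sign and the remaining bits carry the magnitude. A worked example, not part of the patch:

// Wire values decode as 0, -1, 1, -2, 2, ... for x = 0, 1, 2, 3, 4, ...
x := uint64(5)
i := int64(x >> 1) // magnitude: 2
if x&1 != 0 {
	i = ^i // low bit set, so the value is negative: ^2 == -3
}
// i == -3, matching toInt(5)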
+func toInt(x uint64) int64 { + i := int64(x >> 1) + if x&1 != 0 { + i = ^i + } + return i +} + +func (dec *Decoder) nextInt() int64 { + n, _, err := decodeUintReader(&dec.buf, dec.countBuf) + if err != nil { + dec.err = err + } + return toInt(n) +} + +func (dec *Decoder) nextUint() uint64 { + n, _, err := decodeUintReader(&dec.buf, dec.countBuf) + if err != nil { + dec.err = err + } + return n +} + +// decodeTypeSequence parses: +// TypeSequence +// (TypeDefinition DelimitedTypeDefinition*)? +// and returns the type id of the next value. It returns -1 at +// EOF. Upon return, the remainder of dec.buf is the value to be +// decoded. If this is an interface value, it can be ignored by +// simply resetting that buffer. +func (dec *Decoder) decodeTypeSequence(isInterface bool) typeId { + for dec.err == nil { + if dec.buf.Len() == 0 { + if !dec.recvMessage() { + break + } + } + // Receive a type id. + id := typeId(dec.nextInt()) + if id >= 0 { + // Value follows. + return id + } + // Type definition for (-id) follows. + dec.recvType(-id) + // When decoding an interface, after a type there may be a + // DelimitedValue still in the buffer. Skip its count. + // (Alternatively, the buffer is empty and the byte count + // will be absorbed by recvMessage.) + if dec.buf.Len() > 0 { + if !isInterface { + dec.err = os.ErrorString("extra data in buffer") + break + } + dec.nextUint() + } + } + return -1 +} + +// Decode reads the next value from the connection and stores +// it in the data represented by the empty interface value. +// If e is nil, the value will be discarded. Otherwise, +// the value underlying e must either be the correct type for the next +// data item received, and must be a pointer. +func (dec *Decoder) Decode(e interface{}) os.Error { + if e == nil { + return dec.DecodeValue(reflect.Value{}) + } + value := reflect.NewValue(e) + // If e represents a value as opposed to a pointer, the answer won't + // get back to the caller. Make sure it's a pointer. + if value.Type().Kind() != reflect.Ptr { + dec.err = os.ErrorString("gob: attempt to decode into a non-pointer") + return dec.err + } + return dec.DecodeValue(value) +} + +// DecodeValue reads the next value from the connection and stores +// it in the data represented by the reflection value. +// The value must be the correct type for the next +// data item received, or it may be nil, which means the +// value will be discarded. +func (dec *Decoder) DecodeValue(value reflect.Value) os.Error { + // Make sure we're single-threaded through here. + dec.mutex.Lock() + defer dec.mutex.Unlock() + + dec.buf.Reset() // In case data lingers from previous invocation. + dec.err = nil + id := dec.decodeTypeSequence(false) + if dec.err == nil { + dec.decodeValue(id, value) + } + return dec.err +} + +// If debug.go is compiled into the program , debugFunc prints a human-readable +// representation of the gob data read from r by calling that file's Debug function. +// Otherwise it is nil. +var debugFunc func(io.Reader) diff --git a/src/cmd/gofix/testdata/reflect.dnsmsg.go.in b/src/cmd/gofix/testdata/reflect.dnsmsg.go.in new file mode 100644 index 000000000..5209c1a06 --- /dev/null +++ b/src/cmd/gofix/testdata/reflect.dnsmsg.go.in @@ -0,0 +1,779 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// DNS packet assembly. See RFC 1035. +// +// This is intended to support name resolution during net.Dial. 
+// It doesn't have to be blazing fast. +// +// Rather than write the usual handful of routines to pack and +// unpack every message that can appear on the wire, we use +// reflection to write a generic pack/unpack for structs and then +// use it. Thus, if in the future we need to define new message +// structs, no new pack/unpack/printing code needs to be written. +// +// The first half of this file defines the DNS message formats. +// The second half implements the conversion to and from wire format. +// A few of the structure elements have string tags to aid the +// generic pack/unpack routines. +// +// TODO(rsc): There are enough names defined in this file that they're all +// prefixed with dns. Perhaps put this in its own package later. + +package net + +import ( + "fmt" + "os" + "reflect" +) + +// Packet formats + +// Wire constants. +const ( + // valid dnsRR_Header.Rrtype and dnsQuestion.qtype + dnsTypeA = 1 + dnsTypeNS = 2 + dnsTypeMD = 3 + dnsTypeMF = 4 + dnsTypeCNAME = 5 + dnsTypeSOA = 6 + dnsTypeMB = 7 + dnsTypeMG = 8 + dnsTypeMR = 9 + dnsTypeNULL = 10 + dnsTypeWKS = 11 + dnsTypePTR = 12 + dnsTypeHINFO = 13 + dnsTypeMINFO = 14 + dnsTypeMX = 15 + dnsTypeTXT = 16 + dnsTypeAAAA = 28 + dnsTypeSRV = 33 + + // valid dnsQuestion.qtype only + dnsTypeAXFR = 252 + dnsTypeMAILB = 253 + dnsTypeMAILA = 254 + dnsTypeALL = 255 + + // valid dnsQuestion.qclass + dnsClassINET = 1 + dnsClassCSNET = 2 + dnsClassCHAOS = 3 + dnsClassHESIOD = 4 + dnsClassANY = 255 + + // dnsMsg.rcode + dnsRcodeSuccess = 0 + dnsRcodeFormatError = 1 + dnsRcodeServerFailure = 2 + dnsRcodeNameError = 3 + dnsRcodeNotImplemented = 4 + dnsRcodeRefused = 5 +) + +// The wire format for the DNS packet header. +type dnsHeader struct { + Id uint16 + Bits uint16 + Qdcount, Ancount, Nscount, Arcount uint16 +} + +const ( + // dnsHeader.Bits + _QR = 1 << 15 // query/response (response=1) + _AA = 1 << 10 // authoritative + _TC = 1 << 9 // truncated + _RD = 1 << 8 // recursion desired + _RA = 1 << 7 // recursion available +) + +// DNS queries. +type dnsQuestion struct { + Name string "domain-name" // "domain-name" specifies encoding; see packers below + Qtype uint16 + Qclass uint16 +} + +// DNS responses (resource records). +// There are many types of messages, +// but they all share the same header. +type dnsRR_Header struct { + Name string "domain-name" + Rrtype uint16 + Class uint16 + Ttl uint32 + Rdlength uint16 // length of data after header +} + +func (h *dnsRR_Header) Header() *dnsRR_Header { + return h +} + +type dnsRR interface { + Header() *dnsRR_Header +} + + +// Specific DNS RR formats for each query type. 
+ +type dnsRR_CNAME struct { + Hdr dnsRR_Header + Cname string "domain-name" +} + +func (rr *dnsRR_CNAME) Header() *dnsRR_Header { + return &rr.Hdr +} + +type dnsRR_HINFO struct { + Hdr dnsRR_Header + Cpu string + Os string +} + +func (rr *dnsRR_HINFO) Header() *dnsRR_Header { + return &rr.Hdr +} + +type dnsRR_MB struct { + Hdr dnsRR_Header + Mb string "domain-name" +} + +func (rr *dnsRR_MB) Header() *dnsRR_Header { + return &rr.Hdr +} + +type dnsRR_MG struct { + Hdr dnsRR_Header + Mg string "domain-name" +} + +func (rr *dnsRR_MG) Header() *dnsRR_Header { + return &rr.Hdr +} + +type dnsRR_MINFO struct { + Hdr dnsRR_Header + Rmail string "domain-name" + Email string "domain-name" +} + +func (rr *dnsRR_MINFO) Header() *dnsRR_Header { + return &rr.Hdr +} + +type dnsRR_MR struct { + Hdr dnsRR_Header + Mr string "domain-name" +} + +func (rr *dnsRR_MR) Header() *dnsRR_Header { + return &rr.Hdr +} + +type dnsRR_MX struct { + Hdr dnsRR_Header + Pref uint16 + Mx string "domain-name" +} + +func (rr *dnsRR_MX) Header() *dnsRR_Header { + return &rr.Hdr +} + +type dnsRR_NS struct { + Hdr dnsRR_Header + Ns string "domain-name" +} + +func (rr *dnsRR_NS) Header() *dnsRR_Header { + return &rr.Hdr +} + +type dnsRR_PTR struct { + Hdr dnsRR_Header + Ptr string "domain-name" +} + +func (rr *dnsRR_PTR) Header() *dnsRR_Header { + return &rr.Hdr +} + +type dnsRR_SOA struct { + Hdr dnsRR_Header + Ns string "domain-name" + Mbox string "domain-name" + Serial uint32 + Refresh uint32 + Retry uint32 + Expire uint32 + Minttl uint32 +} + +func (rr *dnsRR_SOA) Header() *dnsRR_Header { + return &rr.Hdr +} + +type dnsRR_TXT struct { + Hdr dnsRR_Header + Txt string // not domain name +} + +func (rr *dnsRR_TXT) Header() *dnsRR_Header { + return &rr.Hdr +} + +type dnsRR_SRV struct { + Hdr dnsRR_Header + Priority uint16 + Weight uint16 + Port uint16 + Target string "domain-name" +} + +func (rr *dnsRR_SRV) Header() *dnsRR_Header { + return &rr.Hdr +} + +type dnsRR_A struct { + Hdr dnsRR_Header + A uint32 "ipv4" +} + +func (rr *dnsRR_A) Header() *dnsRR_Header { + return &rr.Hdr +} + +type dnsRR_AAAA struct { + Hdr dnsRR_Header + AAAA [16]byte "ipv6" +} + +func (rr *dnsRR_AAAA) Header() *dnsRR_Header { + return &rr.Hdr +} + +// Packing and unpacking. +// +// All the packers and unpackers take a (msg []byte, off int) +// and return (off1 int, ok bool). If they return ok==false, they +// also return off1==len(msg), so that the next unpacker will +// also fail. This lets us avoid checks of ok until the end of a +// packing sequence. + +// Map of constructors for each RR wire type. +var rr_mk = map[int]func() dnsRR{ + dnsTypeCNAME: func() dnsRR { return new(dnsRR_CNAME) }, + dnsTypeHINFO: func() dnsRR { return new(dnsRR_HINFO) }, + dnsTypeMB: func() dnsRR { return new(dnsRR_MB) }, + dnsTypeMG: func() dnsRR { return new(dnsRR_MG) }, + dnsTypeMINFO: func() dnsRR { return new(dnsRR_MINFO) }, + dnsTypeMR: func() dnsRR { return new(dnsRR_MR) }, + dnsTypeMX: func() dnsRR { return new(dnsRR_MX) }, + dnsTypeNS: func() dnsRR { return new(dnsRR_NS) }, + dnsTypePTR: func() dnsRR { return new(dnsRR_PTR) }, + dnsTypeSOA: func() dnsRR { return new(dnsRR_SOA) }, + dnsTypeTXT: func() dnsRR { return new(dnsRR_TXT) }, + dnsTypeSRV: func() dnsRR { return new(dnsRR_SRV) }, + dnsTypeA: func() dnsRR { return new(dnsRR_A) }, + dnsTypeAAAA: func() dnsRR { return new(dnsRR_AAAA) }, +} + +// Pack a domain name s into msg[off:]. +// Domain names are a sequence of counted strings +// split at the dots. They end with a zero-length string. 
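// (Illustrative aside, not part of the upstream diff.) For example, the name
// "www.example.com." is packed as the byte sequence
//	3 'w' 'w' 'w'  7 'e' 'x' 'a' 'm' 'p' 'l' 'e'  3 'c' 'o' 'm'  0
// i.e. each label preceded by its length, terminated by a zero-length label.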
+func packDomainName(s string, msg []byte, off int) (off1 int, ok bool) { + // Add trailing dot to canonicalize name. + if n := len(s); n == 0 || s[n-1] != '.' { + s += "." + } + + // Each dot ends a segment of the name. + // We trade each dot byte for a length byte. + // There is also a trailing zero. + // Check that we have all the space we need. + tot := len(s) + 1 + if off+tot > len(msg) { + return len(msg), false + } + + // Emit sequence of counted strings, chopping at dots. + begin := 0 + for i := 0; i < len(s); i++ { + if s[i] == '.' { + if i-begin >= 1<<6 { // top two bits of length must be clear + return len(msg), false + } + msg[off] = byte(i - begin) + off++ + for j := begin; j < i; j++ { + msg[off] = s[j] + off++ + } + begin = i + 1 + } + } + msg[off] = 0 + off++ + return off, true +} + +// Unpack a domain name. +// In addition to the simple sequences of counted strings above, +// domain names are allowed to refer to strings elsewhere in the +// packet, to avoid repeating common suffixes when returning +// many entries in a single domain. The pointers are marked +// by a length byte with the top two bits set. Ignoring those +// two bits, that byte and the next give a 14 bit offset from msg[0] +// where we should pick up the trail. +// Note that if we jump elsewhere in the packet, +// we return off1 == the offset after the first pointer we found, +// which is where the next record will start. +// In theory, the pointers are only allowed to jump backward. +// We let them jump anywhere and stop jumping after a while. +func unpackDomainName(msg []byte, off int) (s string, off1 int, ok bool) { + s = "" + ptr := 0 // number of pointers followed +Loop: + for { + if off >= len(msg) { + return "", len(msg), false + } + c := int(msg[off]) + off++ + switch c & 0xC0 { + case 0x00: + if c == 0x00 { + // end of name + break Loop + } + // literal string + if off+c > len(msg) { + return "", len(msg), false + } + s += string(msg[off:off+c]) + "." + off += c + case 0xC0: + // pointer to somewhere else in msg. + // remember location after first ptr, + // since that's how many bytes we consumed. + // also, don't follow too many pointers -- + // maybe there's a loop. + if off >= len(msg) { + return "", len(msg), false + } + c1 := msg[off] + off++ + if ptr == 0 { + off1 = off + } + if ptr++; ptr > 10 { + return "", len(msg), false + } + off = (c^0xC0)<<8 | int(c1) + default: + // 0x80 and 0x40 are reserved + return "", len(msg), false + } + } + if ptr == 0 { + off1 = off + } + return s, off1, true +} + +// TODO(rsc): Move into generic library? +// Pack a reflect.StructValue into msg. Struct members can only be uint16, uint32, string, +// [n]byte, and other (often anonymous) structs. 
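// (Illustrative aside, not part of the upstream diff; q, msg and off are
// hypothetical locals.) Under these rules a question such as
//	q := dnsQuestion{Name: "example.com.", Qtype: dnsTypeA, Qclass: dnsClassINET}
//	off, ok := packStruct(&q, msg, 0)
// emits the counted-string name followed by two big-endian uint16s:
//	7 'e' 'x' 'a' 'm' 'p' 'l' 'e'  3 'c' 'o' 'm'  0  0x00 0x01  0x00 0x01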
+func packStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int, ok bool) { + for i := 0; i < val.NumField(); i++ { + f := val.Type().(*reflect.StructType).Field(i) + switch fv := val.Field(i).(type) { + default: + BadType: + fmt.Fprintf(os.Stderr, "net: dns: unknown packing type %v", f.Type) + return len(msg), false + case *reflect.StructValue: + off, ok = packStructValue(fv, msg, off) + case *reflect.UintValue: + i := fv.Get() + switch fv.Type().Kind() { + default: + goto BadType + case reflect.Uint16: + if off+2 > len(msg) { + return len(msg), false + } + msg[off] = byte(i >> 8) + msg[off+1] = byte(i) + off += 2 + case reflect.Uint32: + if off+4 > len(msg) { + return len(msg), false + } + msg[off] = byte(i >> 24) + msg[off+1] = byte(i >> 16) + msg[off+2] = byte(i >> 8) + msg[off+3] = byte(i) + off += 4 + } + case *reflect.ArrayValue: + if fv.Type().(*reflect.ArrayType).Elem().Kind() != reflect.Uint8 { + goto BadType + } + n := fv.Len() + if off+n > len(msg) { + return len(msg), false + } + reflect.Copy(reflect.NewValue(msg[off:off+n]).(*reflect.SliceValue), fv) + off += n + case *reflect.StringValue: + // There are multiple string encodings. + // The tag distinguishes ordinary strings from domain names. + s := fv.Get() + switch f.Tag { + default: + fmt.Fprintf(os.Stderr, "net: dns: unknown string tag %v", f.Tag) + return len(msg), false + case "domain-name": + off, ok = packDomainName(s, msg, off) + if !ok { + return len(msg), false + } + case "": + // Counted string: 1 byte length. + if len(s) > 255 || off+1+len(s) > len(msg) { + return len(msg), false + } + msg[off] = byte(len(s)) + off++ + off += copy(msg[off:], s) + } + } + } + return off, true +} + +func structValue(any interface{}) *reflect.StructValue { + return reflect.NewValue(any).(*reflect.PtrValue).Elem().(*reflect.StructValue) +} + +func packStruct(any interface{}, msg []byte, off int) (off1 int, ok bool) { + off, ok = packStructValue(structValue(any), msg, off) + return off, ok +} + +// TODO(rsc): Move into generic library? +// Unpack a reflect.StructValue from msg. +// Same restrictions as packStructValue. 
+func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int, ok bool) { + for i := 0; i < val.NumField(); i++ { + f := val.Type().(*reflect.StructType).Field(i) + switch fv := val.Field(i).(type) { + default: + BadType: + fmt.Fprintf(os.Stderr, "net: dns: unknown packing type %v", f.Type) + return len(msg), false + case *reflect.StructValue: + off, ok = unpackStructValue(fv, msg, off) + case *reflect.UintValue: + switch fv.Type().Kind() { + default: + goto BadType + case reflect.Uint16: + if off+2 > len(msg) { + return len(msg), false + } + i := uint16(msg[off])<<8 | uint16(msg[off+1]) + fv.Set(uint64(i)) + off += 2 + case reflect.Uint32: + if off+4 > len(msg) { + return len(msg), false + } + i := uint32(msg[off])<<24 | uint32(msg[off+1])<<16 | uint32(msg[off+2])<<8 | uint32(msg[off+3]) + fv.Set(uint64(i)) + off += 4 + } + case *reflect.ArrayValue: + if fv.Type().(*reflect.ArrayType).Elem().Kind() != reflect.Uint8 { + goto BadType + } + n := fv.Len() + if off+n > len(msg) { + return len(msg), false + } + reflect.Copy(fv, reflect.NewValue(msg[off:off+n]).(*reflect.SliceValue)) + off += n + case *reflect.StringValue: + var s string + switch f.Tag { + default: + fmt.Fprintf(os.Stderr, "net: dns: unknown string tag %v", f.Tag) + return len(msg), false + case "domain-name": + s, off, ok = unpackDomainName(msg, off) + if !ok { + return len(msg), false + } + case "": + if off >= len(msg) || off+1+int(msg[off]) > len(msg) { + return len(msg), false + } + n := int(msg[off]) + off++ + b := make([]byte, n) + for i := 0; i < n; i++ { + b[i] = msg[off+i] + } + off += n + s = string(b) + } + fv.Set(s) + } + } + return off, true +} + +func unpackStruct(any interface{}, msg []byte, off int) (off1 int, ok bool) { + off, ok = unpackStructValue(structValue(any), msg, off) + return off, ok +} + +// Generic struct printer. +// Doesn't care about the string tag "domain-name", +// but does look for an "ipv4" tag on uint32 variables +// and the "ipv6" tag on array variables, +// printing them as IP addresses. +func printStructValue(val *reflect.StructValue) string { + s := "{" + for i := 0; i < val.NumField(); i++ { + if i > 0 { + s += ", " + } + f := val.Type().(*reflect.StructType).Field(i) + if !f.Anonymous { + s += f.Name + "=" + } + fval := val.Field(i) + if fv, ok := fval.(*reflect.StructValue); ok { + s += printStructValue(fv) + } else if fv, ok := fval.(*reflect.UintValue); ok && f.Tag == "ipv4" { + i := fv.Get() + s += IPv4(byte(i>>24), byte(i>>16), byte(i>>8), byte(i)).String() + } else if fv, ok := fval.(*reflect.ArrayValue); ok && f.Tag == "ipv6" { + i := fv.Interface().([]byte) + s += IP(i).String() + } else { + s += fmt.Sprint(fval.Interface()) + } + } + s += "}" + return s +} + +func printStruct(any interface{}) string { return printStructValue(structValue(any)) } + +// Resource record packer. +func packRR(rr dnsRR, msg []byte, off int) (off2 int, ok bool) { + var off1 int + // pack twice, once to find end of header + // and again to find end of packet. + // a bit inefficient but this doesn't need to be fast. + // off1 is end of header + // off2 is end of rr + off1, ok = packStruct(rr.Header(), msg, off) + off2, ok = packStruct(rr, msg, off) + if !ok { + return len(msg), false + } + // pack a third time; redo header with correct data length + rr.Header().Rdlength = uint16(off2 - off1) + packStruct(rr.Header(), msg, off) + return off2, true +} + +// Resource record unpacker. 
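// (Illustrative aside, not part of the upstream diff.) On the wire a resource
// record is its dnsRR_Header followed by Rdlength bytes of data; an A record
// for "example.com." with TTL 60 pointing at the documentation address
// 192.0.2.1 would look like
//	7 'e' 'x' 'a' 'm' 'p' 'l' 'e'  3 'c' 'o' 'm'  0    Name
//	0x00 0x01  0x00 0x01                               Rrtype=A, Class=INET
//	0x00 0x00 0x00 0x3C                                Ttl=60
//	0x00 0x04  0xC0 0x00 0x02 0x01                     Rdlength=4, A=192.0.2.1
// Types not listed in rr_mk are returned as just the header, as below.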
+func unpackRR(msg []byte, off int) (rr dnsRR, off1 int, ok bool) { + // unpack just the header, to find the rr type and length + var h dnsRR_Header + off0 := off + if off, ok = unpackStruct(&h, msg, off); !ok { + return nil, len(msg), false + } + end := off + int(h.Rdlength) + + // make an rr of that type and re-unpack. + // again inefficient but doesn't need to be fast. + mk, known := rr_mk[int(h.Rrtype)] + if !known { + return &h, end, true + } + rr = mk() + off, ok = unpackStruct(rr, msg, off0) + if off != end { + return &h, end, true + } + return rr, off, ok +} + +// Usable representation of a DNS packet. + +// A manually-unpacked version of (id, bits). +// This is in its own struct for easy printing. +type dnsMsgHdr struct { + id uint16 + response bool + opcode int + authoritative bool + truncated bool + recursion_desired bool + recursion_available bool + rcode int +} + +type dnsMsg struct { + dnsMsgHdr + question []dnsQuestion + answer []dnsRR + ns []dnsRR + extra []dnsRR +} + + +func (dns *dnsMsg) Pack() (msg []byte, ok bool) { + var dh dnsHeader + + // Convert convenient dnsMsg into wire-like dnsHeader. + dh.Id = dns.id + dh.Bits = uint16(dns.opcode)<<11 | uint16(dns.rcode) + if dns.recursion_available { + dh.Bits |= _RA + } + if dns.recursion_desired { + dh.Bits |= _RD + } + if dns.truncated { + dh.Bits |= _TC + } + if dns.authoritative { + dh.Bits |= _AA + } + if dns.response { + dh.Bits |= _QR + } + + // Prepare variable sized arrays. + question := dns.question + answer := dns.answer + ns := dns.ns + extra := dns.extra + + dh.Qdcount = uint16(len(question)) + dh.Ancount = uint16(len(answer)) + dh.Nscount = uint16(len(ns)) + dh.Arcount = uint16(len(extra)) + + // Could work harder to calculate message size, + // but this is far more than we need and not + // big enough to hurt the allocator. + msg = make([]byte, 2000) + + // Pack it in: header and then the pieces. + off := 0 + off, ok = packStruct(&dh, msg, off) + for i := 0; i < len(question); i++ { + off, ok = packStruct(&question[i], msg, off) + } + for i := 0; i < len(answer); i++ { + off, ok = packRR(answer[i], msg, off) + } + for i := 0; i < len(ns); i++ { + off, ok = packRR(ns[i], msg, off) + } + for i := 0; i < len(extra); i++ { + off, ok = packRR(extra[i], msg, off) + } + if !ok { + return nil, false + } + return msg[0:off], true +} + +func (dns *dnsMsg) Unpack(msg []byte) bool { + // Header. + var dh dnsHeader + off := 0 + var ok bool + if off, ok = unpackStruct(&dh, msg, off); !ok { + return false + } + dns.id = dh.Id + dns.response = (dh.Bits & _QR) != 0 + dns.opcode = int(dh.Bits>>11) & 0xF + dns.authoritative = (dh.Bits & _AA) != 0 + dns.truncated = (dh.Bits & _TC) != 0 + dns.recursion_desired = (dh.Bits & _RD) != 0 + dns.recursion_available = (dh.Bits & _RA) != 0 + dns.rcode = int(dh.Bits & 0xF) + + // Arrays. 
+ dns.question = make([]dnsQuestion, dh.Qdcount) + dns.answer = make([]dnsRR, dh.Ancount) + dns.ns = make([]dnsRR, dh.Nscount) + dns.extra = make([]dnsRR, dh.Arcount) + + for i := 0; i < len(dns.question); i++ { + off, ok = unpackStruct(&dns.question[i], msg, off) + } + for i := 0; i < len(dns.answer); i++ { + dns.answer[i], off, ok = unpackRR(msg, off) + } + for i := 0; i < len(dns.ns); i++ { + dns.ns[i], off, ok = unpackRR(msg, off) + } + for i := 0; i < len(dns.extra); i++ { + dns.extra[i], off, ok = unpackRR(msg, off) + } + if !ok { + return false + } + // if off != len(msg) { + // println("extra bytes in dns packet", off, "<", len(msg)); + // } + return true +} + +func (dns *dnsMsg) String() string { + s := "DNS: " + printStruct(&dns.dnsMsgHdr) + "\n" + if len(dns.question) > 0 { + s += "-- Questions\n" + for i := 0; i < len(dns.question); i++ { + s += printStruct(&dns.question[i]) + "\n" + } + } + if len(dns.answer) > 0 { + s += "-- Answers\n" + for i := 0; i < len(dns.answer); i++ { + s += printStruct(dns.answer[i]) + "\n" + } + } + if len(dns.ns) > 0 { + s += "-- Name servers\n" + for i := 0; i < len(dns.ns); i++ { + s += printStruct(dns.ns[i]) + "\n" + } + } + if len(dns.extra) > 0 { + s += "-- Extra\n" + for i := 0; i < len(dns.extra); i++ { + s += printStruct(dns.extra[i]) + "\n" + } + } + return s +} diff --git a/src/cmd/gofix/testdata/reflect.dnsmsg.go.out b/src/cmd/gofix/testdata/reflect.dnsmsg.go.out new file mode 100644 index 000000000..546e713a0 --- /dev/null +++ b/src/cmd/gofix/testdata/reflect.dnsmsg.go.out @@ -0,0 +1,779 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// DNS packet assembly. See RFC 1035. +// +// This is intended to support name resolution during net.Dial. +// It doesn't have to be blazing fast. +// +// Rather than write the usual handful of routines to pack and +// unpack every message that can appear on the wire, we use +// reflection to write a generic pack/unpack for structs and then +// use it. Thus, if in the future we need to define new message +// structs, no new pack/unpack/printing code needs to be written. +// +// The first half of this file defines the DNS message formats. +// The second half implements the conversion to and from wire format. +// A few of the structure elements have string tags to aid the +// generic pack/unpack routines. +// +// TODO(rsc): There are enough names defined in this file that they're all +// prefixed with dns. Perhaps put this in its own package later. + +package net + +import ( + "fmt" + "os" + "reflect" +) + +// Packet formats + +// Wire constants. +const ( + // valid dnsRR_Header.Rrtype and dnsQuestion.qtype + dnsTypeA = 1 + dnsTypeNS = 2 + dnsTypeMD = 3 + dnsTypeMF = 4 + dnsTypeCNAME = 5 + dnsTypeSOA = 6 + dnsTypeMB = 7 + dnsTypeMG = 8 + dnsTypeMR = 9 + dnsTypeNULL = 10 + dnsTypeWKS = 11 + dnsTypePTR = 12 + dnsTypeHINFO = 13 + dnsTypeMINFO = 14 + dnsTypeMX = 15 + dnsTypeTXT = 16 + dnsTypeAAAA = 28 + dnsTypeSRV = 33 + + // valid dnsQuestion.qtype only + dnsTypeAXFR = 252 + dnsTypeMAILB = 253 + dnsTypeMAILA = 254 + dnsTypeALL = 255 + + // valid dnsQuestion.qclass + dnsClassINET = 1 + dnsClassCSNET = 2 + dnsClassCHAOS = 3 + dnsClassHESIOD = 4 + dnsClassANY = 255 + + // dnsMsg.rcode + dnsRcodeSuccess = 0 + dnsRcodeFormatError = 1 + dnsRcodeServerFailure = 2 + dnsRcodeNameError = 3 + dnsRcodeNotImplemented = 4 + dnsRcodeRefused = 5 +) + +// The wire format for the DNS packet header. 
+type dnsHeader struct { + Id uint16 + Bits uint16 + Qdcount, Ancount, Nscount, Arcount uint16 +} + +const ( + // dnsHeader.Bits + _QR = 1 << 15 // query/response (response=1) + _AA = 1 << 10 // authoritative + _TC = 1 << 9 // truncated + _RD = 1 << 8 // recursion desired + _RA = 1 << 7 // recursion available +) + +// DNS queries. +type dnsQuestion struct { + Name string "domain-name" // "domain-name" specifies encoding; see packers below + Qtype uint16 + Qclass uint16 +} + +// DNS responses (resource records). +// There are many types of messages, +// but they all share the same header. +type dnsRR_Header struct { + Name string "domain-name" + Rrtype uint16 + Class uint16 + Ttl uint32 + Rdlength uint16 // length of data after header +} + +func (h *dnsRR_Header) Header() *dnsRR_Header { + return h +} + +type dnsRR interface { + Header() *dnsRR_Header +} + + +// Specific DNS RR formats for each query type. + +type dnsRR_CNAME struct { + Hdr dnsRR_Header + Cname string "domain-name" +} + +func (rr *dnsRR_CNAME) Header() *dnsRR_Header { + return &rr.Hdr +} + +type dnsRR_HINFO struct { + Hdr dnsRR_Header + Cpu string + Os string +} + +func (rr *dnsRR_HINFO) Header() *dnsRR_Header { + return &rr.Hdr +} + +type dnsRR_MB struct { + Hdr dnsRR_Header + Mb string "domain-name" +} + +func (rr *dnsRR_MB) Header() *dnsRR_Header { + return &rr.Hdr +} + +type dnsRR_MG struct { + Hdr dnsRR_Header + Mg string "domain-name" +} + +func (rr *dnsRR_MG) Header() *dnsRR_Header { + return &rr.Hdr +} + +type dnsRR_MINFO struct { + Hdr dnsRR_Header + Rmail string "domain-name" + Email string "domain-name" +} + +func (rr *dnsRR_MINFO) Header() *dnsRR_Header { + return &rr.Hdr +} + +type dnsRR_MR struct { + Hdr dnsRR_Header + Mr string "domain-name" +} + +func (rr *dnsRR_MR) Header() *dnsRR_Header { + return &rr.Hdr +} + +type dnsRR_MX struct { + Hdr dnsRR_Header + Pref uint16 + Mx string "domain-name" +} + +func (rr *dnsRR_MX) Header() *dnsRR_Header { + return &rr.Hdr +} + +type dnsRR_NS struct { + Hdr dnsRR_Header + Ns string "domain-name" +} + +func (rr *dnsRR_NS) Header() *dnsRR_Header { + return &rr.Hdr +} + +type dnsRR_PTR struct { + Hdr dnsRR_Header + Ptr string "domain-name" +} + +func (rr *dnsRR_PTR) Header() *dnsRR_Header { + return &rr.Hdr +} + +type dnsRR_SOA struct { + Hdr dnsRR_Header + Ns string "domain-name" + Mbox string "domain-name" + Serial uint32 + Refresh uint32 + Retry uint32 + Expire uint32 + Minttl uint32 +} + +func (rr *dnsRR_SOA) Header() *dnsRR_Header { + return &rr.Hdr +} + +type dnsRR_TXT struct { + Hdr dnsRR_Header + Txt string // not domain name +} + +func (rr *dnsRR_TXT) Header() *dnsRR_Header { + return &rr.Hdr +} + +type dnsRR_SRV struct { + Hdr dnsRR_Header + Priority uint16 + Weight uint16 + Port uint16 + Target string "domain-name" +} + +func (rr *dnsRR_SRV) Header() *dnsRR_Header { + return &rr.Hdr +} + +type dnsRR_A struct { + Hdr dnsRR_Header + A uint32 "ipv4" +} + +func (rr *dnsRR_A) Header() *dnsRR_Header { + return &rr.Hdr +} + +type dnsRR_AAAA struct { + Hdr dnsRR_Header + AAAA [16]byte "ipv6" +} + +func (rr *dnsRR_AAAA) Header() *dnsRR_Header { + return &rr.Hdr +} + +// Packing and unpacking. +// +// All the packers and unpackers take a (msg []byte, off int) +// and return (off1 int, ok bool). If they return ok==false, they +// also return off1==len(msg), so that the next unpacker will +// also fail. This lets us avoid checks of ok until the end of a +// packing sequence. + +// Map of constructors for each RR wire type. 
+var rr_mk = map[int]func() dnsRR{ + dnsTypeCNAME: func() dnsRR { return new(dnsRR_CNAME) }, + dnsTypeHINFO: func() dnsRR { return new(dnsRR_HINFO) }, + dnsTypeMB: func() dnsRR { return new(dnsRR_MB) }, + dnsTypeMG: func() dnsRR { return new(dnsRR_MG) }, + dnsTypeMINFO: func() dnsRR { return new(dnsRR_MINFO) }, + dnsTypeMR: func() dnsRR { return new(dnsRR_MR) }, + dnsTypeMX: func() dnsRR { return new(dnsRR_MX) }, + dnsTypeNS: func() dnsRR { return new(dnsRR_NS) }, + dnsTypePTR: func() dnsRR { return new(dnsRR_PTR) }, + dnsTypeSOA: func() dnsRR { return new(dnsRR_SOA) }, + dnsTypeTXT: func() dnsRR { return new(dnsRR_TXT) }, + dnsTypeSRV: func() dnsRR { return new(dnsRR_SRV) }, + dnsTypeA: func() dnsRR { return new(dnsRR_A) }, + dnsTypeAAAA: func() dnsRR { return new(dnsRR_AAAA) }, +} + +// Pack a domain name s into msg[off:]. +// Domain names are a sequence of counted strings +// split at the dots. They end with a zero-length string. +func packDomainName(s string, msg []byte, off int) (off1 int, ok bool) { + // Add trailing dot to canonicalize name. + if n := len(s); n == 0 || s[n-1] != '.' { + s += "." + } + + // Each dot ends a segment of the name. + // We trade each dot byte for a length byte. + // There is also a trailing zero. + // Check that we have all the space we need. + tot := len(s) + 1 + if off+tot > len(msg) { + return len(msg), false + } + + // Emit sequence of counted strings, chopping at dots. + begin := 0 + for i := 0; i < len(s); i++ { + if s[i] == '.' { + if i-begin >= 1<<6 { // top two bits of length must be clear + return len(msg), false + } + msg[off] = byte(i - begin) + off++ + for j := begin; j < i; j++ { + msg[off] = s[j] + off++ + } + begin = i + 1 + } + } + msg[off] = 0 + off++ + return off, true +} + +// Unpack a domain name. +// In addition to the simple sequences of counted strings above, +// domain names are allowed to refer to strings elsewhere in the +// packet, to avoid repeating common suffixes when returning +// many entries in a single domain. The pointers are marked +// by a length byte with the top two bits set. Ignoring those +// two bits, that byte and the next give a 14 bit offset from msg[0] +// where we should pick up the trail. +// Note that if we jump elsewhere in the packet, +// we return off1 == the offset after the first pointer we found, +// which is where the next record will start. +// In theory, the pointers are only allowed to jump backward. +// We let them jump anywhere and stop jumping after a while. +func unpackDomainName(msg []byte, off int) (s string, off1 int, ok bool) { + s = "" + ptr := 0 // number of pointers followed +Loop: + for { + if off >= len(msg) { + return "", len(msg), false + } + c := int(msg[off]) + off++ + switch c & 0xC0 { + case 0x00: + if c == 0x00 { + // end of name + break Loop + } + // literal string + if off+c > len(msg) { + return "", len(msg), false + } + s += string(msg[off:off+c]) + "." + off += c + case 0xC0: + // pointer to somewhere else in msg. + // remember location after first ptr, + // since that's how many bytes we consumed. + // also, don't follow too many pointers -- + // maybe there's a loop. + if off >= len(msg) { + return "", len(msg), false + } + c1 := msg[off] + off++ + if ptr == 0 { + off1 = off + } + if ptr++; ptr > 10 { + return "", len(msg), false + } + off = (c^0xC0)<<8 | int(c1) + default: + // 0x80 and 0x40 are reserved + return "", len(msg), false + } + } + if ptr == 0 { + off1 = off + } + return s, off1, true +} + +// TODO(rsc): Move into generic library? 
+// Pack a reflect.StructValue into msg. Struct members can only be uint16, uint32, string, +// [n]byte, and other (often anonymous) structs. +func packStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok bool) { + for i := 0; i < val.NumField(); i++ { + f := val.Type().Field(i) + switch fv := val.Field(i); fv.Kind() { + default: + BadType: + fmt.Fprintf(os.Stderr, "net: dns: unknown packing type %v", f.Type) + return len(msg), false + case reflect.Struct: + off, ok = packStructValue(fv, msg, off) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + i := fv.Uint() + switch fv.Type().Kind() { + default: + goto BadType + case reflect.Uint16: + if off+2 > len(msg) { + return len(msg), false + } + msg[off] = byte(i >> 8) + msg[off+1] = byte(i) + off += 2 + case reflect.Uint32: + if off+4 > len(msg) { + return len(msg), false + } + msg[off] = byte(i >> 24) + msg[off+1] = byte(i >> 16) + msg[off+2] = byte(i >> 8) + msg[off+3] = byte(i) + off += 4 + } + case reflect.Array: + if fv.Type().Elem().Kind() != reflect.Uint8 { + goto BadType + } + n := fv.Len() + if off+n > len(msg) { + return len(msg), false + } + reflect.Copy(reflect.NewValue(msg[off:off+n]), fv) + off += n + case reflect.String: + // There are multiple string encodings. + // The tag distinguishes ordinary strings from domain names. + s := fv.String() + switch f.Tag { + default: + fmt.Fprintf(os.Stderr, "net: dns: unknown string tag %v", f.Tag) + return len(msg), false + case "domain-name": + off, ok = packDomainName(s, msg, off) + if !ok { + return len(msg), false + } + case "": + // Counted string: 1 byte length. + if len(s) > 255 || off+1+len(s) > len(msg) { + return len(msg), false + } + msg[off] = byte(len(s)) + off++ + off += copy(msg[off:], s) + } + } + } + return off, true +} + +func structValue(any interface{}) reflect.Value { + return reflect.NewValue(any).Elem() +} + +func packStruct(any interface{}, msg []byte, off int) (off1 int, ok bool) { + off, ok = packStructValue(structValue(any), msg, off) + return off, ok +} + +// TODO(rsc): Move into generic library? +// Unpack a reflect.StructValue from msg. +// Same restrictions as packStructValue. 
+func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok bool) { + for i := 0; i < val.NumField(); i++ { + f := val.Type().Field(i) + switch fv := val.Field(i); fv.Kind() { + default: + BadType: + fmt.Fprintf(os.Stderr, "net: dns: unknown packing type %v", f.Type) + return len(msg), false + case reflect.Struct: + off, ok = unpackStructValue(fv, msg, off) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + switch fv.Type().Kind() { + default: + goto BadType + case reflect.Uint16: + if off+2 > len(msg) { + return len(msg), false + } + i := uint16(msg[off])<<8 | uint16(msg[off+1]) + fv.SetUint(uint64(i)) + off += 2 + case reflect.Uint32: + if off+4 > len(msg) { + return len(msg), false + } + i := uint32(msg[off])<<24 | uint32(msg[off+1])<<16 | uint32(msg[off+2])<<8 | uint32(msg[off+3]) + fv.SetUint(uint64(i)) + off += 4 + } + case reflect.Array: + if fv.Type().Elem().Kind() != reflect.Uint8 { + goto BadType + } + n := fv.Len() + if off+n > len(msg) { + return len(msg), false + } + reflect.Copy(fv, reflect.NewValue(msg[off:off+n])) + off += n + case reflect.String: + var s string + switch f.Tag { + default: + fmt.Fprintf(os.Stderr, "net: dns: unknown string tag %v", f.Tag) + return len(msg), false + case "domain-name": + s, off, ok = unpackDomainName(msg, off) + if !ok { + return len(msg), false + } + case "": + if off >= len(msg) || off+1+int(msg[off]) > len(msg) { + return len(msg), false + } + n := int(msg[off]) + off++ + b := make([]byte, n) + for i := 0; i < n; i++ { + b[i] = msg[off+i] + } + off += n + s = string(b) + } + fv.SetString(s) + } + } + return off, true +} + +func unpackStruct(any interface{}, msg []byte, off int) (off1 int, ok bool) { + off, ok = unpackStructValue(structValue(any), msg, off) + return off, ok +} + +// Generic struct printer. +// Doesn't care about the string tag "domain-name", +// but does look for an "ipv4" tag on uint32 variables +// and the "ipv6" tag on array variables, +// printing them as IP addresses. +func printStructValue(val reflect.Value) string { + s := "{" + for i := 0; i < val.NumField(); i++ { + if i > 0 { + s += ", " + } + f := val.Type().Field(i) + if !f.Anonymous { + s += f.Name + "=" + } + fval := val.Field(i) + if fv := fval; fv.Kind() == reflect.Struct { + s += printStructValue(fv) + } else if fv := fval; (fv.Kind() == reflect.Uint || fv.Kind() == reflect.Uint8 || fv.Kind() == reflect.Uint16 || fv.Kind() == reflect.Uint32 || fv.Kind() == reflect.Uint64 || fv.Kind() == reflect.Uintptr) && f.Tag == "ipv4" { + i := fv.Uint() + s += IPv4(byte(i>>24), byte(i>>16), byte(i>>8), byte(i)).String() + } else if fv := fval; fv.Kind() == reflect.Array && f.Tag == "ipv6" { + i := fv.Interface().([]byte) + s += IP(i).String() + } else { + s += fmt.Sprint(fval.Interface()) + } + } + s += "}" + return s +} + +func printStruct(any interface{}) string { return printStructValue(structValue(any)) } + +// Resource record packer. +func packRR(rr dnsRR, msg []byte, off int) (off2 int, ok bool) { + var off1 int + // pack twice, once to find end of header + // and again to find end of packet. + // a bit inefficient but this doesn't need to be fast. 
+ // off1 is end of header + // off2 is end of rr + off1, ok = packStruct(rr.Header(), msg, off) + off2, ok = packStruct(rr, msg, off) + if !ok { + return len(msg), false + } + // pack a third time; redo header with correct data length + rr.Header().Rdlength = uint16(off2 - off1) + packStruct(rr.Header(), msg, off) + return off2, true +} + +// Resource record unpacker. +func unpackRR(msg []byte, off int) (rr dnsRR, off1 int, ok bool) { + // unpack just the header, to find the rr type and length + var h dnsRR_Header + off0 := off + if off, ok = unpackStruct(&h, msg, off); !ok { + return nil, len(msg), false + } + end := off + int(h.Rdlength) + + // make an rr of that type and re-unpack. + // again inefficient but doesn't need to be fast. + mk, known := rr_mk[int(h.Rrtype)] + if !known { + return &h, end, true + } + rr = mk() + off, ok = unpackStruct(rr, msg, off0) + if off != end { + return &h, end, true + } + return rr, off, ok +} + +// Usable representation of a DNS packet. + +// A manually-unpacked version of (id, bits). +// This is in its own struct for easy printing. +type dnsMsgHdr struct { + id uint16 + response bool + opcode int + authoritative bool + truncated bool + recursion_desired bool + recursion_available bool + rcode int +} + +type dnsMsg struct { + dnsMsgHdr + question []dnsQuestion + answer []dnsRR + ns []dnsRR + extra []dnsRR +} + + +func (dns *dnsMsg) Pack() (msg []byte, ok bool) { + var dh dnsHeader + + // Convert convenient dnsMsg into wire-like dnsHeader. + dh.Id = dns.id + dh.Bits = uint16(dns.opcode)<<11 | uint16(dns.rcode) + if dns.recursion_available { + dh.Bits |= _RA + } + if dns.recursion_desired { + dh.Bits |= _RD + } + if dns.truncated { + dh.Bits |= _TC + } + if dns.authoritative { + dh.Bits |= _AA + } + if dns.response { + dh.Bits |= _QR + } + + // Prepare variable sized arrays. + question := dns.question + answer := dns.answer + ns := dns.ns + extra := dns.extra + + dh.Qdcount = uint16(len(question)) + dh.Ancount = uint16(len(answer)) + dh.Nscount = uint16(len(ns)) + dh.Arcount = uint16(len(extra)) + + // Could work harder to calculate message size, + // but this is far more than we need and not + // big enough to hurt the allocator. + msg = make([]byte, 2000) + + // Pack it in: header and then the pieces. + off := 0 + off, ok = packStruct(&dh, msg, off) + for i := 0; i < len(question); i++ { + off, ok = packStruct(&question[i], msg, off) + } + for i := 0; i < len(answer); i++ { + off, ok = packRR(answer[i], msg, off) + } + for i := 0; i < len(ns); i++ { + off, ok = packRR(ns[i], msg, off) + } + for i := 0; i < len(extra); i++ { + off, ok = packRR(extra[i], msg, off) + } + if !ok { + return nil, false + } + return msg[0:off], true +} + +func (dns *dnsMsg) Unpack(msg []byte) bool { + // Header. + var dh dnsHeader + off := 0 + var ok bool + if off, ok = unpackStruct(&dh, msg, off); !ok { + return false + } + dns.id = dh.Id + dns.response = (dh.Bits & _QR) != 0 + dns.opcode = int(dh.Bits>>11) & 0xF + dns.authoritative = (dh.Bits & _AA) != 0 + dns.truncated = (dh.Bits & _TC) != 0 + dns.recursion_desired = (dh.Bits & _RD) != 0 + dns.recursion_available = (dh.Bits & _RA) != 0 + dns.rcode = int(dh.Bits & 0xF) + + // Arrays. 
+ dns.question = make([]dnsQuestion, dh.Qdcount) + dns.answer = make([]dnsRR, dh.Ancount) + dns.ns = make([]dnsRR, dh.Nscount) + dns.extra = make([]dnsRR, dh.Arcount) + + for i := 0; i < len(dns.question); i++ { + off, ok = unpackStruct(&dns.question[i], msg, off) + } + for i := 0; i < len(dns.answer); i++ { + dns.answer[i], off, ok = unpackRR(msg, off) + } + for i := 0; i < len(dns.ns); i++ { + dns.ns[i], off, ok = unpackRR(msg, off) + } + for i := 0; i < len(dns.extra); i++ { + dns.extra[i], off, ok = unpackRR(msg, off) + } + if !ok { + return false + } + // if off != len(msg) { + // println("extra bytes in dns packet", off, "<", len(msg)); + // } + return true +} + +func (dns *dnsMsg) String() string { + s := "DNS: " + printStruct(&dns.dnsMsgHdr) + "\n" + if len(dns.question) > 0 { + s += "-- Questions\n" + for i := 0; i < len(dns.question); i++ { + s += printStruct(&dns.question[i]) + "\n" + } + } + if len(dns.answer) > 0 { + s += "-- Answers\n" + for i := 0; i < len(dns.answer); i++ { + s += printStruct(dns.answer[i]) + "\n" + } + } + if len(dns.ns) > 0 { + s += "-- Name servers\n" + for i := 0; i < len(dns.ns); i++ { + s += printStruct(dns.ns[i]) + "\n" + } + } + if len(dns.extra) > 0 { + s += "-- Extra\n" + for i := 0; i < len(dns.extra); i++ { + s += printStruct(dns.extra[i]) + "\n" + } + } + return s +} diff --git a/src/cmd/gofix/testdata/reflect.encode.go.in b/src/cmd/gofix/testdata/reflect.encode.go.in new file mode 100644 index 000000000..26ce47039 --- /dev/null +++ b/src/cmd/gofix/testdata/reflect.encode.go.in @@ -0,0 +1,367 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// The json package implements encoding and decoding of JSON objects as +// defined in RFC 4627. +package json + +import ( + "bytes" + "encoding/base64" + "os" + "reflect" + "runtime" + "sort" + "strconv" + "unicode" + "utf8" +) + +// Marshal returns the JSON encoding of v. +// +// Marshal traverses the value v recursively. +// If an encountered value implements the Marshaler interface, +// Marshal calls its MarshalJSON method to produce JSON. +// +// Otherwise, Marshal uses the following type-dependent default encodings: +// +// Boolean values encode as JSON booleans. +// +// Floating point and integer values encode as JSON numbers. +// +// String values encode as JSON strings, with each invalid UTF-8 sequence +// replaced by the encoding of the Unicode replacement character U+FFFD. +// +// Array and slice values encode as JSON arrays, except that +// []byte encodes as a base64-encoded string. +// +// Struct values encode as JSON objects. Each struct field becomes +// a member of the object. By default the object's key name is the +// struct field name. If the struct field has a non-empty tag consisting +// of only Unicode letters, digits, and underscores, that tag will be used +// as the name instead. Only exported fields will be encoded. +// +// Map values encode as JSON objects. +// The map's key type must be string; the object keys are used directly +// as map keys. +// +// Pointer values encode as the value pointed to. +// A nil pointer encodes as the null JSON object. +// +// Interface values encode as the value contained in the interface. +// A nil interface value encodes as the null JSON object. +// +// Channel, complex, and function values cannot be encoded in JSON. +// Attempting to encode such a value causes Marshal to return +// an InvalidTypeError. 
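// (Illustrative aside, not part of the upstream diff; Point is a hypothetical
// caller-defined type.) Typical use:
//
//	type Point struct{ X, Y int }
//	b, err := Marshal(Point{1, 2})	// on success b is []byte(`{"X":1,"Y":2}`)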
+// +// JSON cannot represent cyclic data structures and Marshal does not +// handle them. Passing cyclic structures to Marshal will result in +// an infinite recursion. +// +func Marshal(v interface{}) ([]byte, os.Error) { + e := &encodeState{} + err := e.marshal(v) + if err != nil { + return nil, err + } + return e.Bytes(), nil +} + +// MarshalIndent is like Marshal but applies Indent to format the output. +func MarshalIndent(v interface{}, prefix, indent string) ([]byte, os.Error) { + b, err := Marshal(v) + if err != nil { + return nil, err + } + var buf bytes.Buffer + err = Indent(&buf, b, prefix, indent) + if err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +// MarshalForHTML is like Marshal but applies HTMLEscape to the output. +func MarshalForHTML(v interface{}) ([]byte, os.Error) { + b, err := Marshal(v) + if err != nil { + return nil, err + } + var buf bytes.Buffer + HTMLEscape(&buf, b) + return buf.Bytes(), nil +} + +// HTMLEscape appends to dst the JSON-encoded src with <, >, and & +// characters inside string literals changed to \u003c, \u003e, \u0026 +// so that the JSON will be safe to embed inside HTML <script> tags. +// For historical reasons, web browsers don't honor standard HTML +// escaping within <script> tags, so an alternative JSON encoding must +// be used. +func HTMLEscape(dst *bytes.Buffer, src []byte) { + // < > & can only appear in string literals, + // so just scan the string one byte at a time. + start := 0 + for i, c := range src { + if c == '<' || c == '>' || c == '&' { + if start < i { + dst.Write(src[start:i]) + } + dst.WriteString(`\u00`) + dst.WriteByte(hex[c>>4]) + dst.WriteByte(hex[c&0xF]) + start = i + 1 + } + } + if start < len(src) { + dst.Write(src[start:]) + } +} + +// Marshaler is the interface implemented by objects that +// can marshal themselves into valid JSON. +type Marshaler interface { + MarshalJSON() ([]byte, os.Error) +} + +type UnsupportedTypeError struct { + Type reflect.Type +} + +func (e *UnsupportedTypeError) String() string { + return "json: unsupported type: " + e.Type.String() +} + +type InvalidUTF8Error struct { + S string +} + +func (e *InvalidUTF8Error) String() string { + return "json: invalid UTF-8 in string: " + strconv.Quote(e.S) +} + +type MarshalerError struct { + Type reflect.Type + Error os.Error +} + +func (e *MarshalerError) String() string { + return "json: error calling MarshalJSON for type " + e.Type.String() + ": " + e.Error.String() +} + +type interfaceOrPtrValue interface { + IsNil() bool + Elem() reflect.Value +} + +var hex = "0123456789abcdef" + +// An encodeState encodes JSON into a bytes.Buffer. +type encodeState struct { + bytes.Buffer // accumulated output +} + +func (e *encodeState) marshal(v interface{}) (err os.Error) { + defer func() { + if r := recover(); r != nil { + if _, ok := r.(runtime.Error); ok { + panic(r) + } + err = r.(os.Error) + } + }() + e.reflectValue(reflect.NewValue(v)) + return nil +} + +func (e *encodeState) error(err os.Error) { + panic(err) +} + +var byteSliceType = reflect.Typeof([]byte(nil)) + +func (e *encodeState) reflectValue(v reflect.Value) { + if v == nil { + e.WriteString("null") + return + } + + if j, ok := v.Interface().(Marshaler); ok { + b, err := j.MarshalJSON() + if err == nil { + // copy JSON into buffer, checking validity. 
+ err = Compact(&e.Buffer, b) + } + if err != nil { + e.error(&MarshalerError{v.Type(), err}) + } + return + } + + switch v := v.(type) { + case *reflect.BoolValue: + x := v.Get() + if x { + e.WriteString("true") + } else { + e.WriteString("false") + } + + case *reflect.IntValue: + e.WriteString(strconv.Itoa64(v.Get())) + + case *reflect.UintValue: + e.WriteString(strconv.Uitoa64(v.Get())) + + case *reflect.FloatValue: + e.WriteString(strconv.FtoaN(v.Get(), 'g', -1, v.Type().Bits())) + + case *reflect.StringValue: + e.string(v.Get()) + + case *reflect.StructValue: + e.WriteByte('{') + t := v.Type().(*reflect.StructType) + n := v.NumField() + first := true + for i := 0; i < n; i++ { + f := t.Field(i) + if f.PkgPath != "" { + continue + } + if first { + first = false + } else { + e.WriteByte(',') + } + if isValidTag(f.Tag) { + e.string(f.Tag) + } else { + e.string(f.Name) + } + e.WriteByte(':') + e.reflectValue(v.Field(i)) + } + e.WriteByte('}') + + case *reflect.MapValue: + if _, ok := v.Type().(*reflect.MapType).Key().(*reflect.StringType); !ok { + e.error(&UnsupportedTypeError{v.Type()}) + } + if v.IsNil() { + e.WriteString("null") + break + } + e.WriteByte('{') + var sv stringValues = v.Keys() + sort.Sort(sv) + for i, k := range sv { + if i > 0 { + e.WriteByte(',') + } + e.string(k.(*reflect.StringValue).Get()) + e.WriteByte(':') + e.reflectValue(v.Elem(k)) + } + e.WriteByte('}') + + case reflect.ArrayOrSliceValue: + if v.Type() == byteSliceType { + e.WriteByte('"') + s := v.Interface().([]byte) + if len(s) < 1024 { + // for small buffers, using Encode directly is much faster. + dst := make([]byte, base64.StdEncoding.EncodedLen(len(s))) + base64.StdEncoding.Encode(dst, s) + e.Write(dst) + } else { + // for large buffers, avoid unnecessary extra temporary + // buffer space. + enc := base64.NewEncoder(base64.StdEncoding, e) + enc.Write(s) + enc.Close() + } + e.WriteByte('"') + break + } + e.WriteByte('[') + n := v.Len() + for i := 0; i < n; i++ { + if i > 0 { + e.WriteByte(',') + } + e.reflectValue(v.Elem(i)) + } + e.WriteByte(']') + + case interfaceOrPtrValue: + if v.IsNil() { + e.WriteString("null") + return + } + e.reflectValue(v.Elem()) + + default: + e.error(&UnsupportedTypeError{v.Type()}) + } + return +} + +func isValidTag(s string) bool { + if s == "" { + return false + } + for _, c := range s { + if c != '_' && !unicode.IsLetter(c) && !unicode.IsDigit(c) { + return false + } + } + return true +} + +// stringValues is a slice of reflect.Value holding *reflect.StringValue. +// It implements the methods to sort by string. 
+type stringValues []reflect.Value + +func (sv stringValues) Len() int { return len(sv) } +func (sv stringValues) Swap(i, j int) { sv[i], sv[j] = sv[j], sv[i] } +func (sv stringValues) Less(i, j int) bool { return sv.get(i) < sv.get(j) } +func (sv stringValues) get(i int) string { return sv[i].(*reflect.StringValue).Get() } + +func (e *encodeState) string(s string) { + e.WriteByte('"') + start := 0 + for i := 0; i < len(s); { + if b := s[i]; b < utf8.RuneSelf { + if 0x20 <= b && b != '\\' && b != '"' { + i++ + continue + } + if start < i { + e.WriteString(s[start:i]) + } + if b == '\\' || b == '"' { + e.WriteByte('\\') + e.WriteByte(b) + } else { + e.WriteString(`\u00`) + e.WriteByte(hex[b>>4]) + e.WriteByte(hex[b&0xF]) + } + i++ + start = i + continue + } + c, size := utf8.DecodeRuneInString(s[i:]) + if c == utf8.RuneError && size == 1 { + e.error(&InvalidUTF8Error{s}) + } + i += size + } + if start < len(s) { + e.WriteString(s[start:]) + } + e.WriteByte('"') +} diff --git a/src/cmd/gofix/testdata/reflect.encode.go.out b/src/cmd/gofix/testdata/reflect.encode.go.out new file mode 100644 index 000000000..8c79a27d4 --- /dev/null +++ b/src/cmd/gofix/testdata/reflect.encode.go.out @@ -0,0 +1,367 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// The json package implements encoding and decoding of JSON objects as +// defined in RFC 4627. +package json + +import ( + "bytes" + "encoding/base64" + "os" + "reflect" + "runtime" + "sort" + "strconv" + "unicode" + "utf8" +) + +// Marshal returns the JSON encoding of v. +// +// Marshal traverses the value v recursively. +// If an encountered value implements the Marshaler interface, +// Marshal calls its MarshalJSON method to produce JSON. +// +// Otherwise, Marshal uses the following type-dependent default encodings: +// +// Boolean values encode as JSON booleans. +// +// Floating point and integer values encode as JSON numbers. +// +// String values encode as JSON strings, with each invalid UTF-8 sequence +// replaced by the encoding of the Unicode replacement character U+FFFD. +// +// Array and slice values encode as JSON arrays, except that +// []byte encodes as a base64-encoded string. +// +// Struct values encode as JSON objects. Each struct field becomes +// a member of the object. By default the object's key name is the +// struct field name. If the struct field has a non-empty tag consisting +// of only Unicode letters, digits, and underscores, that tag will be used +// as the name instead. Only exported fields will be encoded. +// +// Map values encode as JSON objects. +// The map's key type must be string; the object keys are used directly +// as map keys. +// +// Pointer values encode as the value pointed to. +// A nil pointer encodes as the null JSON object. +// +// Interface values encode as the value contained in the interface. +// A nil interface value encodes as the null JSON object. +// +// Channel, complex, and function values cannot be encoded in JSON. +// Attempting to encode such a value causes Marshal to return +// an InvalidTypeError. +// +// JSON cannot represent cyclic data structures and Marshal does not +// handle them. Passing cyclic structures to Marshal will result in +// an infinite recursion. 
+// +func Marshal(v interface{}) ([]byte, os.Error) { + e := &encodeState{} + err := e.marshal(v) + if err != nil { + return nil, err + } + return e.Bytes(), nil +} + +// MarshalIndent is like Marshal but applies Indent to format the output. +func MarshalIndent(v interface{}, prefix, indent string) ([]byte, os.Error) { + b, err := Marshal(v) + if err != nil { + return nil, err + } + var buf bytes.Buffer + err = Indent(&buf, b, prefix, indent) + if err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +// MarshalForHTML is like Marshal but applies HTMLEscape to the output. +func MarshalForHTML(v interface{}) ([]byte, os.Error) { + b, err := Marshal(v) + if err != nil { + return nil, err + } + var buf bytes.Buffer + HTMLEscape(&buf, b) + return buf.Bytes(), nil +} + +// HTMLEscape appends to dst the JSON-encoded src with <, >, and & +// characters inside string literals changed to \u003c, \u003e, \u0026 +// so that the JSON will be safe to embed inside HTML <script> tags. +// For historical reasons, web browsers don't honor standard HTML +// escaping within <script> tags, so an alternative JSON encoding must +// be used. +func HTMLEscape(dst *bytes.Buffer, src []byte) { + // < > & can only appear in string literals, + // so just scan the string one byte at a time. + start := 0 + for i, c := range src { + if c == '<' || c == '>' || c == '&' { + if start < i { + dst.Write(src[start:i]) + } + dst.WriteString(`\u00`) + dst.WriteByte(hex[c>>4]) + dst.WriteByte(hex[c&0xF]) + start = i + 1 + } + } + if start < len(src) { + dst.Write(src[start:]) + } +} + +// Marshaler is the interface implemented by objects that +// can marshal themselves into valid JSON. +type Marshaler interface { + MarshalJSON() ([]byte, os.Error) +} + +type UnsupportedTypeError struct { + Type reflect.Type +} + +func (e *UnsupportedTypeError) String() string { + return "json: unsupported type: " + e.Type.String() +} + +type InvalidUTF8Error struct { + S string +} + +func (e *InvalidUTF8Error) String() string { + return "json: invalid UTF-8 in string: " + strconv.Quote(e.S) +} + +type MarshalerError struct { + Type reflect.Type + Error os.Error +} + +func (e *MarshalerError) String() string { + return "json: error calling MarshalJSON for type " + e.Type.String() + ": " + e.Error.String() +} + +type interfaceOrPtrValue interface { + IsNil() bool + Elem() reflect.Value +} + +var hex = "0123456789abcdef" + +// An encodeState encodes JSON into a bytes.Buffer. +type encodeState struct { + bytes.Buffer // accumulated output +} + +func (e *encodeState) marshal(v interface{}) (err os.Error) { + defer func() { + if r := recover(); r != nil { + if _, ok := r.(runtime.Error); ok { + panic(r) + } + err = r.(os.Error) + } + }() + e.reflectValue(reflect.NewValue(v)) + return nil +} + +func (e *encodeState) error(err os.Error) { + panic(err) +} + +var byteSliceType = reflect.Typeof([]byte(nil)) + +func (e *encodeState) reflectValue(v reflect.Value) { + if !v.IsValid() { + e.WriteString("null") + return + } + + if j, ok := v.Interface().(Marshaler); ok { + b, err := j.MarshalJSON() + if err == nil { + // copy JSON into buffer, checking validity. 
+ err = Compact(&e.Buffer, b) + } + if err != nil { + e.error(&MarshalerError{v.Type(), err}) + } + return + } + + switch v.Kind() { + case reflect.Bool: + x := v.Bool() + if x { + e.WriteString("true") + } else { + e.WriteString("false") + } + + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + e.WriteString(strconv.Itoa64(v.Int())) + + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + e.WriteString(strconv.Uitoa64(v.Uint())) + + case reflect.Float32, reflect.Float64: + e.WriteString(strconv.FtoaN(v.Float(), 'g', -1, v.Type().Bits())) + + case reflect.String: + e.string(v.String()) + + case reflect.Struct: + e.WriteByte('{') + t := v.Type() + n := v.NumField() + first := true + for i := 0; i < n; i++ { + f := t.Field(i) + if f.PkgPath != "" { + continue + } + if first { + first = false + } else { + e.WriteByte(',') + } + if isValidTag(f.Tag) { + e.string(f.Tag) + } else { + e.string(f.Name) + } + e.WriteByte(':') + e.reflectValue(v.Field(i)) + } + e.WriteByte('}') + + case reflect.Map: + if v.Type().Key().Kind() != reflect.String { + e.error(&UnsupportedTypeError{v.Type()}) + } + if v.IsNil() { + e.WriteString("null") + break + } + e.WriteByte('{') + var sv stringValues = v.MapKeys() + sort.Sort(sv) + for i, k := range sv { + if i > 0 { + e.WriteByte(',') + } + e.string(k.String()) + e.WriteByte(':') + e.reflectValue(v.MapIndex(k)) + } + e.WriteByte('}') + + case reflect.Array, reflect.Slice: + if v.Type() == byteSliceType { + e.WriteByte('"') + s := v.Interface().([]byte) + if len(s) < 1024 { + // for small buffers, using Encode directly is much faster. + dst := make([]byte, base64.StdEncoding.EncodedLen(len(s))) + base64.StdEncoding.Encode(dst, s) + e.Write(dst) + } else { + // for large buffers, avoid unnecessary extra temporary + // buffer space. + enc := base64.NewEncoder(base64.StdEncoding, e) + enc.Write(s) + enc.Close() + } + e.WriteByte('"') + break + } + e.WriteByte('[') + n := v.Len() + for i := 0; i < n; i++ { + if i > 0 { + e.WriteByte(',') + } + e.reflectValue(v.Index(i)) + } + e.WriteByte(']') + + case interfaceOrPtrValue: + if v.IsNil() { + e.WriteString("null") + return + } + e.reflectValue(v.Elem()) + + default: + e.error(&UnsupportedTypeError{v.Type()}) + } + return +} + +func isValidTag(s string) bool { + if s == "" { + return false + } + for _, c := range s { + if c != '_' && !unicode.IsLetter(c) && !unicode.IsDigit(c) { + return false + } + } + return true +} + +// stringValues is a slice of reflect.Value holding *reflect.StringValue. +// It implements the methods to sort by string. 
+type stringValues []reflect.Value + +func (sv stringValues) Len() int { return len(sv) } +func (sv stringValues) Swap(i, j int) { sv[i], sv[j] = sv[j], sv[i] } +func (sv stringValues) Less(i, j int) bool { return sv.get(i) < sv.get(j) } +func (sv stringValues) get(i int) string { return sv[i].String() } + +func (e *encodeState) string(s string) { + e.WriteByte('"') + start := 0 + for i := 0; i < len(s); { + if b := s[i]; b < utf8.RuneSelf { + if 0x20 <= b && b != '\\' && b != '"' { + i++ + continue + } + if start < i { + e.WriteString(s[start:i]) + } + if b == '\\' || b == '"' { + e.WriteByte('\\') + e.WriteByte(b) + } else { + e.WriteString(`\u00`) + e.WriteByte(hex[b>>4]) + e.WriteByte(hex[b&0xF]) + } + i++ + start = i + continue + } + c, size := utf8.DecodeRuneInString(s[i:]) + if c == utf8.RuneError && size == 1 { + e.error(&InvalidUTF8Error{s}) + } + i += size + } + if start < len(s) { + e.WriteString(s[start:]) + } + e.WriteByte('"') +} diff --git a/src/cmd/gofix/testdata/reflect.encoder.go.in b/src/cmd/gofix/testdata/reflect.encoder.go.in new file mode 100644 index 000000000..e52a4de29 --- /dev/null +++ b/src/cmd/gofix/testdata/reflect.encoder.go.in @@ -0,0 +1,240 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package gob + +import ( + "bytes" + "io" + "os" + "reflect" + "sync" +) + +// An Encoder manages the transmission of type and data information to the +// other side of a connection. +type Encoder struct { + mutex sync.Mutex // each item must be sent atomically + w []io.Writer // where to send the data + sent map[reflect.Type]typeId // which types we've already sent + countState *encoderState // stage for writing counts + freeList *encoderState // list of free encoderStates; avoids reallocation + buf []byte // for collecting the output. + byteBuf bytes.Buffer // buffer for top-level encoderState + err os.Error +} + +// NewEncoder returns a new encoder that will transmit on the io.Writer. +func NewEncoder(w io.Writer) *Encoder { + enc := new(Encoder) + enc.w = []io.Writer{w} + enc.sent = make(map[reflect.Type]typeId) + enc.countState = enc.newEncoderState(new(bytes.Buffer)) + return enc +} + +// writer() returns the innermost writer the encoder is using +func (enc *Encoder) writer() io.Writer { + return enc.w[len(enc.w)-1] +} + +// pushWriter adds a writer to the encoder. +func (enc *Encoder) pushWriter(w io.Writer) { + enc.w = append(enc.w, w) +} + +// popWriter pops the innermost writer. +func (enc *Encoder) popWriter() { + enc.w = enc.w[0 : len(enc.w)-1] +} + +func (enc *Encoder) badType(rt reflect.Type) { + enc.setError(os.ErrorString("gob: can't encode type " + rt.String())) +} + +func (enc *Encoder) setError(err os.Error) { + if enc.err == nil { // remember the first. + enc.err = err + } +} + +// writeMessage sends the data item preceded by a unsigned count of its length. +func (enc *Encoder) writeMessage(w io.Writer, b *bytes.Buffer) { + enc.countState.encodeUint(uint64(b.Len())) + // Build the buffer. + countLen := enc.countState.b.Len() + total := countLen + b.Len() + if total > len(enc.buf) { + enc.buf = make([]byte, total+1000) // extra for growth + } + // Place the length before the data. + // TODO(r): avoid the extra copy here. + enc.countState.b.Read(enc.buf[0:countLen]) + // Now the data. + b.Read(enc.buf[countLen:total]) + // Write the data. 
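	// (Illustrative aside, not part of the upstream diff.) At this point enc.buf
	// holds the framed message: the gob-encoded unsigned byte count followed by
	// the payload bytes. Counts below 128 occupy a single byte, so a 3-byte
	// payload "abc" goes out as the 4 bytes 0x03 'a' 'b' 'c'.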
+ _, err := w.Write(enc.buf[0:total]) + if err != nil { + enc.setError(err) + } +} + +// sendActualType sends the requested type, without further investigation, unless +// it's been sent before. +func (enc *Encoder) sendActualType(w io.Writer, state *encoderState, ut *userTypeInfo, actual reflect.Type) (sent bool) { + if _, alreadySent := enc.sent[actual]; alreadySent { + return false + } + typeLock.Lock() + info, err := getTypeInfo(ut) + typeLock.Unlock() + if err != nil { + enc.setError(err) + return + } + // Send the pair (-id, type) + // Id: + state.encodeInt(-int64(info.id)) + // Type: + enc.encode(state.b, reflect.NewValue(info.wire), wireTypeUserInfo) + enc.writeMessage(w, state.b) + if enc.err != nil { + return + } + + // Remember we've sent this type, both what the user gave us and the base type. + enc.sent[ut.base] = info.id + if ut.user != ut.base { + enc.sent[ut.user] = info.id + } + // Now send the inner types + switch st := actual.(type) { + case *reflect.StructType: + for i := 0; i < st.NumField(); i++ { + enc.sendType(w, state, st.Field(i).Type) + } + case reflect.ArrayOrSliceType: + enc.sendType(w, state, st.Elem()) + } + return true +} + +// sendType sends the type info to the other side, if necessary. +func (enc *Encoder) sendType(w io.Writer, state *encoderState, origt reflect.Type) (sent bool) { + ut := userType(origt) + if ut.isGobEncoder { + // The rules are different: regardless of the underlying type's representation, + // we need to tell the other side that this exact type is a GobEncoder. + return enc.sendActualType(w, state, ut, ut.user) + } + + // It's a concrete value, so drill down to the base type. + switch rt := ut.base.(type) { + default: + // Basic types and interfaces do not need to be described. + return + case *reflect.SliceType: + // If it's []uint8, don't send; it's considered basic. + if rt.Elem().Kind() == reflect.Uint8 { + return + } + // Otherwise we do send. + break + case *reflect.ArrayType: + // arrays must be sent so we know their lengths and element types. + break + case *reflect.MapType: + // maps must be sent so we know their lengths and key/value types. + break + case *reflect.StructType: + // structs must be sent so we know their fields. + break + case *reflect.ChanType, *reflect.FuncType: + // Probably a bad field in a struct. + enc.badType(rt) + return + } + + return enc.sendActualType(w, state, ut, ut.base) +} + +// Encode transmits the data item represented by the empty interface value, +// guaranteeing that all necessary type information has been transmitted first. +func (enc *Encoder) Encode(e interface{}) os.Error { + return enc.EncodeValue(reflect.NewValue(e)) +} + +// sendTypeDescriptor makes sure the remote side knows about this type. +// It will send a descriptor if this is the first time the type has been +// sent. +func (enc *Encoder) sendTypeDescriptor(w io.Writer, state *encoderState, ut *userTypeInfo) { + // Make sure the type is known to the other side. + // First, have we already sent this type? + rt := ut.base + if ut.isGobEncoder { + rt = ut.user + } + if _, alreadySent := enc.sent[rt]; !alreadySent { + // No, so send it. + sent := enc.sendType(w, state, rt) + if enc.err != nil { + return + } + // If the type info has still not been transmitted, it means we have + // a singleton basic type (int, []byte etc.) at top level. We don't + // need to send the type info but we do need to update enc.sent. 
+ if !sent { + typeLock.Lock() + info, err := getTypeInfo(ut) + typeLock.Unlock() + if err != nil { + enc.setError(err) + return + } + enc.sent[rt] = info.id + } + } +} + +// sendTypeId sends the id, which must have already been defined. +func (enc *Encoder) sendTypeId(state *encoderState, ut *userTypeInfo) { + // Identify the type of this top-level value. + state.encodeInt(int64(enc.sent[ut.base])) +} + +// EncodeValue transmits the data item represented by the reflection value, +// guaranteeing that all necessary type information has been transmitted first. +func (enc *Encoder) EncodeValue(value reflect.Value) os.Error { + // Make sure we're single-threaded through here, so multiple + // goroutines can share an encoder. + enc.mutex.Lock() + defer enc.mutex.Unlock() + + // Remove any nested writers remaining due to previous errors. + enc.w = enc.w[0:1] + + ut, err := validUserType(value.Type()) + if err != nil { + return err + } + + enc.err = nil + enc.byteBuf.Reset() + state := enc.newEncoderState(&enc.byteBuf) + + enc.sendTypeDescriptor(enc.writer(), state, ut) + enc.sendTypeId(state, ut) + if enc.err != nil { + return enc.err + } + + // Encode the object. + enc.encode(state.b, value, ut) + if enc.err == nil { + enc.writeMessage(enc.writer(), state.b) + } + + enc.freeEncoderState(state) + return enc.err +} diff --git a/src/cmd/gofix/testdata/reflect.encoder.go.out b/src/cmd/gofix/testdata/reflect.encoder.go.out new file mode 100644 index 000000000..928f3b244 --- /dev/null +++ b/src/cmd/gofix/testdata/reflect.encoder.go.out @@ -0,0 +1,240 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package gob + +import ( + "bytes" + "io" + "os" + "reflect" + "sync" +) + +// An Encoder manages the transmission of type and data information to the +// other side of a connection. +type Encoder struct { + mutex sync.Mutex // each item must be sent atomically + w []io.Writer // where to send the data + sent map[reflect.Type]typeId // which types we've already sent + countState *encoderState // stage for writing counts + freeList *encoderState // list of free encoderStates; avoids reallocation + buf []byte // for collecting the output. + byteBuf bytes.Buffer // buffer for top-level encoderState + err os.Error +} + +// NewEncoder returns a new encoder that will transmit on the io.Writer. +func NewEncoder(w io.Writer) *Encoder { + enc := new(Encoder) + enc.w = []io.Writer{w} + enc.sent = make(map[reflect.Type]typeId) + enc.countState = enc.newEncoderState(new(bytes.Buffer)) + return enc +} + +// writer() returns the innermost writer the encoder is using +func (enc *Encoder) writer() io.Writer { + return enc.w[len(enc.w)-1] +} + +// pushWriter adds a writer to the encoder. +func (enc *Encoder) pushWriter(w io.Writer) { + enc.w = append(enc.w, w) +} + +// popWriter pops the innermost writer. +func (enc *Encoder) popWriter() { + enc.w = enc.w[0 : len(enc.w)-1] +} + +func (enc *Encoder) badType(rt reflect.Type) { + enc.setError(os.ErrorString("gob: can't encode type " + rt.String())) +} + +func (enc *Encoder) setError(err os.Error) { + if enc.err == nil { // remember the first. + enc.err = err + } +} + +// writeMessage sends the data item preceded by a unsigned count of its length. +func (enc *Encoder) writeMessage(w io.Writer, b *bytes.Buffer) { + enc.countState.encodeUint(uint64(b.Len())) + // Build the buffer. 
+ countLen := enc.countState.b.Len() + total := countLen + b.Len() + if total > len(enc.buf) { + enc.buf = make([]byte, total+1000) // extra for growth + } + // Place the length before the data. + // TODO(r): avoid the extra copy here. + enc.countState.b.Read(enc.buf[0:countLen]) + // Now the data. + b.Read(enc.buf[countLen:total]) + // Write the data. + _, err := w.Write(enc.buf[0:total]) + if err != nil { + enc.setError(err) + } +} + +// sendActualType sends the requested type, without further investigation, unless +// it's been sent before. +func (enc *Encoder) sendActualType(w io.Writer, state *encoderState, ut *userTypeInfo, actual reflect.Type) (sent bool) { + if _, alreadySent := enc.sent[actual]; alreadySent { + return false + } + typeLock.Lock() + info, err := getTypeInfo(ut) + typeLock.Unlock() + if err != nil { + enc.setError(err) + return + } + // Send the pair (-id, type) + // Id: + state.encodeInt(-int64(info.id)) + // Type: + enc.encode(state.b, reflect.NewValue(info.wire), wireTypeUserInfo) + enc.writeMessage(w, state.b) + if enc.err != nil { + return + } + + // Remember we've sent this type, both what the user gave us and the base type. + enc.sent[ut.base] = info.id + if ut.user != ut.base { + enc.sent[ut.user] = info.id + } + // Now send the inner types + switch st := actual; st.Kind() { + case reflect.Struct: + for i := 0; i < st.NumField(); i++ { + enc.sendType(w, state, st.Field(i).Type) + } + case reflect.Array, reflect.Slice: + enc.sendType(w, state, st.Elem()) + } + return true +} + +// sendType sends the type info to the other side, if necessary. +func (enc *Encoder) sendType(w io.Writer, state *encoderState, origt reflect.Type) (sent bool) { + ut := userType(origt) + if ut.isGobEncoder { + // The rules are different: regardless of the underlying type's representation, + // we need to tell the other side that this exact type is a GobEncoder. + return enc.sendActualType(w, state, ut, ut.user) + } + + // It's a concrete value, so drill down to the base type. + switch rt := ut.base; rt.Kind() { + default: + // Basic types and interfaces do not need to be described. + return + case reflect.Slice: + // If it's []uint8, don't send; it's considered basic. + if rt.Elem().Kind() == reflect.Uint8 { + return + } + // Otherwise we do send. + break + case reflect.Array: + // arrays must be sent so we know their lengths and element types. + break + case reflect.Map: + // maps must be sent so we know their lengths and key/value types. + break + case reflect.Struct: + // structs must be sent so we know their fields. + break + case reflect.Chan, reflect.Func: + // Probably a bad field in a struct. + enc.badType(rt) + return + } + + return enc.sendActualType(w, state, ut, ut.base) +} + +// Encode transmits the data item represented by the empty interface value, +// guaranteeing that all necessary type information has been transmitted first. +func (enc *Encoder) Encode(e interface{}) os.Error { + return enc.EncodeValue(reflect.NewValue(e)) +} + +// sendTypeDescriptor makes sure the remote side knows about this type. +// It will send a descriptor if this is the first time the type has been +// sent. +func (enc *Encoder) sendTypeDescriptor(w io.Writer, state *encoderState, ut *userTypeInfo) { + // Make sure the type is known to the other side. + // First, have we already sent this type? + rt := ut.base + if ut.isGobEncoder { + rt = ut.user + } + if _, alreadySent := enc.sent[rt]; !alreadySent { + // No, so send it. 
+ sent := enc.sendType(w, state, rt) + if enc.err != nil { + return + } + // If the type info has still not been transmitted, it means we have + // a singleton basic type (int, []byte etc.) at top level. We don't + // need to send the type info but we do need to update enc.sent. + if !sent { + typeLock.Lock() + info, err := getTypeInfo(ut) + typeLock.Unlock() + if err != nil { + enc.setError(err) + return + } + enc.sent[rt] = info.id + } + } +} + +// sendTypeId sends the id, which must have already been defined. +func (enc *Encoder) sendTypeId(state *encoderState, ut *userTypeInfo) { + // Identify the type of this top-level value. + state.encodeInt(int64(enc.sent[ut.base])) +} + +// EncodeValue transmits the data item represented by the reflection value, +// guaranteeing that all necessary type information has been transmitted first. +func (enc *Encoder) EncodeValue(value reflect.Value) os.Error { + // Make sure we're single-threaded through here, so multiple + // goroutines can share an encoder. + enc.mutex.Lock() + defer enc.mutex.Unlock() + + // Remove any nested writers remaining due to previous errors. + enc.w = enc.w[0:1] + + ut, err := validUserType(value.Type()) + if err != nil { + return err + } + + enc.err = nil + enc.byteBuf.Reset() + state := enc.newEncoderState(&enc.byteBuf) + + enc.sendTypeDescriptor(enc.writer(), state, ut) + enc.sendTypeId(state, ut) + if enc.err != nil { + return enc.err + } + + // Encode the object. + enc.encode(state.b, value, ut) + if enc.err == nil { + enc.writeMessage(enc.writer(), state.b) + } + + enc.freeEncoderState(state) + return enc.err +} diff --git a/src/cmd/gofix/testdata/reflect.export.go.in b/src/cmd/gofix/testdata/reflect.export.go.in new file mode 100644 index 000000000..e91e777e3 --- /dev/null +++ b/src/cmd/gofix/testdata/reflect.export.go.in @@ -0,0 +1,400 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +/* + The netchan package implements type-safe networked channels: + it allows the two ends of a channel to appear on different + computers connected by a network. It does this by transporting + data sent to a channel on one machine so it can be recovered + by a receive of a channel of the same type on the other. + + An exporter publishes a set of channels by name. An importer + connects to the exporting machine and imports the channels + by name. After importing the channels, the two machines can + use the channels in the usual way. + + Networked channels are not synchronized; they always behave + as if they are buffered channels of at least one element. +*/ +package netchan + +// BUG: can't use range clause to receive when using ImportNValues to limit the count. + +import ( + "log" + "io" + "net" + "os" + "reflect" + "strconv" + "sync" +) + +// Export + +// expLog is a logging convenience function. The first argument must be a string. +func expLog(args ...interface{}) { + args[0] = "netchan export: " + args[0].(string) + log.Print(args...) +} + +// An Exporter allows a set of channels to be published on a single +// network port. A single machine may have multiple Exporters +// but they must use different ports. 
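Before the Exporter type below, a minimal sketch of the export side of this API as used in this file (NewExporter, ListenAndServe and Export are all defined later in the same file). The channel name and address are arbitrary, and the example assumes the netchan convention that exporting with Send means this side sends and the importer receives.

// Sketch only, inside package netchan: publish one string channel.
func exportGreetings() os.Error {
	exp := NewExporter()
	if err := exp.ListenAndServe("tcp", "127.0.0.1:7070"); err != nil {
		return err
	}
	ch := make(chan string)
	// Send: this side sends; the importer receives (assumed convention).
	if err := exp.Export("greetings", ch, Send); err != nil {
		return err
	}
	ch <- "hello" // blocks until an importer has asked for "greetings"
	return nil
}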
+type Exporter struct { + *clientSet +} + +type expClient struct { + *encDec + exp *Exporter + chans map[int]*netChan // channels in use by client + mu sync.Mutex // protects remaining fields + errored bool // client has been sent an error + seqNum int64 // sequences messages sent to client; has value of highest sent + ackNum int64 // highest sequence number acknowledged + seqLock sync.Mutex // guarantees messages are in sequence, only locked under mu +} + +func newClient(exp *Exporter, conn io.ReadWriter) *expClient { + client := new(expClient) + client.exp = exp + client.encDec = newEncDec(conn) + client.seqNum = 0 + client.ackNum = 0 + client.chans = make(map[int]*netChan) + return client +} + +func (client *expClient) sendError(hdr *header, err string) { + error := &error{err} + expLog("sending error to client:", error.Error) + client.encode(hdr, payError, error) // ignore any encode error, hope client gets it + client.mu.Lock() + client.errored = true + client.mu.Unlock() +} + +func (client *expClient) newChan(hdr *header, dir Dir, name string, size int, count int64) *netChan { + exp := client.exp + exp.mu.Lock() + ech, ok := exp.names[name] + exp.mu.Unlock() + if !ok { + client.sendError(hdr, "no such channel: "+name) + return nil + } + if ech.dir != dir { + client.sendError(hdr, "wrong direction for channel: "+name) + return nil + } + nch := newNetChan(name, hdr.Id, ech, client.encDec, size, count) + client.chans[hdr.Id] = nch + return nch +} + +func (client *expClient) getChan(hdr *header, dir Dir) *netChan { + nch := client.chans[hdr.Id] + if nch == nil { + return nil + } + if nch.dir != dir { + client.sendError(hdr, "wrong direction for channel: "+nch.name) + } + return nch +} + +// The function run manages sends and receives for a single client. For each +// (client Recv) request, this will launch a serveRecv goroutine to deliver +// the data for that channel, while (client Send) requests are handled as +// data arrives from the client. +func (client *expClient) run() { + hdr := new(header) + hdrValue := reflect.NewValue(hdr) + req := new(request) + reqValue := reflect.NewValue(req) + error := new(error) + for { + *hdr = header{} + if err := client.decode(hdrValue); err != nil { + if err != os.EOF { + expLog("error decoding client header:", err) + } + break + } + switch hdr.PayloadType { + case payRequest: + *req = request{} + if err := client.decode(reqValue); err != nil { + expLog("error decoding client request:", err) + break + } + if req.Size < 1 { + panic("netchan: remote requested " + strconv.Itoa(req.Size) + " values") + } + switch req.Dir { + case Recv: + // look up channel before calling serveRecv to + // avoid a lock around client.chans. + if nch := client.newChan(hdr, Send, req.Name, req.Size, req.Count); nch != nil { + go client.serveRecv(nch, *hdr, req.Count) + } + case Send: + client.newChan(hdr, Recv, req.Name, req.Size, req.Count) + // The actual sends will have payload type payData. + // TODO: manage the count? + default: + error.Error = "request: can't handle channel direction" + expLog(error.Error, req.Dir) + client.encode(hdr, payError, error) + } + case payData: + client.serveSend(*hdr) + case payClosed: + client.serveClosed(*hdr) + case payAck: + client.mu.Lock() + if client.ackNum != hdr.SeqNum-1 { + // Since the sequence number is incremented and the message is sent + // in a single instance of locking client.mu, the messages are guaranteed + // to be sent in order. 
Therefore receipt of acknowledgement N means + // all messages <=N have been seen by the recipient. We check anyway. + expLog("sequence out of order:", client.ackNum, hdr.SeqNum) + } + if client.ackNum < hdr.SeqNum { // If there has been an error, don't back up the count. + client.ackNum = hdr.SeqNum + } + client.mu.Unlock() + case payAckSend: + if nch := client.getChan(hdr, Send); nch != nil { + nch.acked() + } + default: + log.Fatal("netchan export: unknown payload type", hdr.PayloadType) + } + } + client.exp.delClient(client) +} + +// Send all the data on a single channel to a client asking for a Recv. +// The header is passed by value to avoid issues of overwriting. +func (client *expClient) serveRecv(nch *netChan, hdr header, count int64) { + for { + val, ok := nch.recv() + if !ok { + if err := client.encode(&hdr, payClosed, nil); err != nil { + expLog("error encoding server closed message:", err) + } + break + } + // We hold the lock during transmission to guarantee messages are + // sent in sequence number order. Also, we increment first so the + // value of client.SeqNum is the value of the highest used sequence + // number, not one beyond. + client.mu.Lock() + client.seqNum++ + hdr.SeqNum = client.seqNum + client.seqLock.Lock() // guarantee ordering of messages + client.mu.Unlock() + err := client.encode(&hdr, payData, val.Interface()) + client.seqLock.Unlock() + if err != nil { + expLog("error encoding client response:", err) + client.sendError(&hdr, err.String()) + break + } + // Negative count means run forever. + if count >= 0 { + if count--; count <= 0 { + break + } + } + } +} + +// Receive and deliver locally one item from a client asking for a Send +// The header is passed by value to avoid issues of overwriting. +func (client *expClient) serveSend(hdr header) { + nch := client.getChan(&hdr, Recv) + if nch == nil { + return + } + // Create a new value for each received item. + val := reflect.MakeZero(nch.ch.Type().(*reflect.ChanType).Elem()) + if err := client.decode(val); err != nil { + expLog("value decode:", err, "; type ", nch.ch.Type()) + return + } + nch.send(val) +} + +// Report that client has closed the channel that is sending to us. +// The header is passed by value to avoid issues of overwriting. +func (client *expClient) serveClosed(hdr header) { + nch := client.getChan(&hdr, Recv) + if nch == nil { + return + } + nch.close() +} + +func (client *expClient) unackedCount() int64 { + client.mu.Lock() + n := client.seqNum - client.ackNum + client.mu.Unlock() + return n +} + +func (client *expClient) seq() int64 { + client.mu.Lock() + n := client.seqNum + client.mu.Unlock() + return n +} + +func (client *expClient) ack() int64 { + client.mu.Lock() + n := client.seqNum + client.mu.Unlock() + return n +} + +// Serve waits for incoming connections on the listener +// and serves the Exporter's channels on each. +// It blocks until the listener is closed. +func (exp *Exporter) Serve(listener net.Listener) { + for { + conn, err := listener.Accept() + if err != nil { + expLog("listen:", err) + break + } + go exp.ServeConn(conn) + } +} + +// ServeConn exports the Exporter's channels on conn. +// It blocks until the connection is terminated. +func (exp *Exporter) ServeConn(conn io.ReadWriter) { + exp.addClient(conn).run() +} + +// NewExporter creates a new Exporter that exports a set of channels. 
+func NewExporter() *Exporter { + e := &Exporter{ + clientSet: &clientSet{ + names: make(map[string]*chanDir), + clients: make(map[unackedCounter]bool), + }, + } + return e +} + +// ListenAndServe exports the exporter's channels through the +// given network and local address defined as in net.Listen. +func (exp *Exporter) ListenAndServe(network, localaddr string) os.Error { + listener, err := net.Listen(network, localaddr) + if err != nil { + return err + } + go exp.Serve(listener) + return nil +} + +// addClient creates a new expClient and records its existence +func (exp *Exporter) addClient(conn io.ReadWriter) *expClient { + client := newClient(exp, conn) + exp.mu.Lock() + exp.clients[client] = true + exp.mu.Unlock() + return client +} + +// delClient forgets the client existed +func (exp *Exporter) delClient(client *expClient) { + exp.mu.Lock() + exp.clients[client] = false, false + exp.mu.Unlock() +} + +// Drain waits until all messages sent from this exporter/importer, including +// those not yet sent to any client and possibly including those sent while +// Drain was executing, have been received by the importer. In short, it +// waits until all the exporter's messages have been received by a client. +// If the timeout (measured in nanoseconds) is positive and Drain takes +// longer than that to complete, an error is returned. +func (exp *Exporter) Drain(timeout int64) os.Error { + // This wrapper function is here so the method's comment will appear in godoc. + return exp.clientSet.drain(timeout) +} + +// Sync waits until all clients of the exporter have received the messages +// that were sent at the time Sync was invoked. Unlike Drain, it does not +// wait for messages sent while it is running or messages that have not been +// dispatched to any client. If the timeout (measured in nanoseconds) is +// positive and Sync takes longer than that to complete, an error is +// returned. +func (exp *Exporter) Sync(timeout int64) os.Error { + // This wrapper function is here so the method's comment will appear in godoc. + return exp.clientSet.sync(timeout) +} + +func checkChan(chT interface{}, dir Dir) (*reflect.ChanValue, os.Error) { + chanType, ok := reflect.Typeof(chT).(*reflect.ChanType) + if !ok { + return nil, os.ErrorString("not a channel") + } + if dir != Send && dir != Recv { + return nil, os.ErrorString("unknown channel direction") + } + switch chanType.Dir() { + case reflect.BothDir: + case reflect.SendDir: + if dir != Recv { + return nil, os.ErrorString("to import/export with Send, must provide <-chan") + } + case reflect.RecvDir: + if dir != Send { + return nil, os.ErrorString("to import/export with Recv, must provide chan<-") + } + } + return reflect.NewValue(chT).(*reflect.ChanValue), nil +} + +// Export exports a channel of a given type and specified direction. The +// channel to be exported is provided in the call and may be of arbitrary +// channel type. +// Despite the literal signature, the effective signature is +// Export(name string, chT chan T, dir Dir) +func (exp *Exporter) Export(name string, chT interface{}, dir Dir) os.Error { + ch, err := checkChan(chT, dir) + if err != nil { + return err + } + exp.mu.Lock() + defer exp.mu.Unlock() + _, present := exp.names[name] + if present { + return os.ErrorString("channel name already being exported:" + name) + } + exp.names[name] = &chanDir{ch, dir} + return nil +} + +// Hangup disassociates the named channel from the Exporter and closes +// the channel. Messages in flight for the channel may be dropped. 
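Before Hangup below, note what this .in file is exercising: the pre-Go1 reflect API used in checkChan above. A condensed before/after, taken verbatim from this file and from its reflect.export.go.out counterpart further down (an excerpt, not a compilable unit on its own):

// reflect.export.go.in (old API):
chanType, ok := reflect.Typeof(chT).(*reflect.ChanType)
if !ok {
	return nil, os.ErrorString("not a channel")
}
...
return reflect.NewValue(chT).(*reflect.ChanValue), nil

// reflect.export.go.out (new API):
chanType := reflect.Typeof(chT)
if chanType.Kind() != reflect.Chan {
	return reflect.Value{}, os.ErrorString("not a channel")
}
...
return reflect.NewValue(chT), nil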
+func (exp *Exporter) Hangup(name string) os.Error { + exp.mu.Lock() + chDir, ok := exp.names[name] + if ok { + exp.names[name] = nil, false + } + // TODO drop all instances of channel from client sets + exp.mu.Unlock() + if !ok { + return os.ErrorString("netchan export: hangup: no such channel: " + name) + } + chDir.ch.Close() + return nil +} diff --git a/src/cmd/gofix/testdata/reflect.export.go.out b/src/cmd/gofix/testdata/reflect.export.go.out new file mode 100644 index 000000000..2209f04e8 --- /dev/null +++ b/src/cmd/gofix/testdata/reflect.export.go.out @@ -0,0 +1,400 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +/* + The netchan package implements type-safe networked channels: + it allows the two ends of a channel to appear on different + computers connected by a network. It does this by transporting + data sent to a channel on one machine so it can be recovered + by a receive of a channel of the same type on the other. + + An exporter publishes a set of channels by name. An importer + connects to the exporting machine and imports the channels + by name. After importing the channels, the two machines can + use the channels in the usual way. + + Networked channels are not synchronized; they always behave + as if they are buffered channels of at least one element. +*/ +package netchan + +// BUG: can't use range clause to receive when using ImportNValues to limit the count. + +import ( + "log" + "io" + "net" + "os" + "reflect" + "strconv" + "sync" +) + +// Export + +// expLog is a logging convenience function. The first argument must be a string. +func expLog(args ...interface{}) { + args[0] = "netchan export: " + args[0].(string) + log.Print(args...) +} + +// An Exporter allows a set of channels to be published on a single +// network port. A single machine may have multiple Exporters +// but they must use different ports. 
+type Exporter struct { + *clientSet +} + +type expClient struct { + *encDec + exp *Exporter + chans map[int]*netChan // channels in use by client + mu sync.Mutex // protects remaining fields + errored bool // client has been sent an error + seqNum int64 // sequences messages sent to client; has value of highest sent + ackNum int64 // highest sequence number acknowledged + seqLock sync.Mutex // guarantees messages are in sequence, only locked under mu +} + +func newClient(exp *Exporter, conn io.ReadWriter) *expClient { + client := new(expClient) + client.exp = exp + client.encDec = newEncDec(conn) + client.seqNum = 0 + client.ackNum = 0 + client.chans = make(map[int]*netChan) + return client +} + +func (client *expClient) sendError(hdr *header, err string) { + error := &error{err} + expLog("sending error to client:", error.Error) + client.encode(hdr, payError, error) // ignore any encode error, hope client gets it + client.mu.Lock() + client.errored = true + client.mu.Unlock() +} + +func (client *expClient) newChan(hdr *header, dir Dir, name string, size int, count int64) *netChan { + exp := client.exp + exp.mu.Lock() + ech, ok := exp.names[name] + exp.mu.Unlock() + if !ok { + client.sendError(hdr, "no such channel: "+name) + return nil + } + if ech.dir != dir { + client.sendError(hdr, "wrong direction for channel: "+name) + return nil + } + nch := newNetChan(name, hdr.Id, ech, client.encDec, size, count) + client.chans[hdr.Id] = nch + return nch +} + +func (client *expClient) getChan(hdr *header, dir Dir) *netChan { + nch := client.chans[hdr.Id] + if nch == nil { + return nil + } + if nch.dir != dir { + client.sendError(hdr, "wrong direction for channel: "+nch.name) + } + return nch +} + +// The function run manages sends and receives for a single client. For each +// (client Recv) request, this will launch a serveRecv goroutine to deliver +// the data for that channel, while (client Send) requests are handled as +// data arrives from the client. +func (client *expClient) run() { + hdr := new(header) + hdrValue := reflect.NewValue(hdr) + req := new(request) + reqValue := reflect.NewValue(req) + error := new(error) + for { + *hdr = header{} + if err := client.decode(hdrValue); err != nil { + if err != os.EOF { + expLog("error decoding client header:", err) + } + break + } + switch hdr.PayloadType { + case payRequest: + *req = request{} + if err := client.decode(reqValue); err != nil { + expLog("error decoding client request:", err) + break + } + if req.Size < 1 { + panic("netchan: remote requested " + strconv.Itoa(req.Size) + " values") + } + switch req.Dir { + case Recv: + // look up channel before calling serveRecv to + // avoid a lock around client.chans. + if nch := client.newChan(hdr, Send, req.Name, req.Size, req.Count); nch != nil { + go client.serveRecv(nch, *hdr, req.Count) + } + case Send: + client.newChan(hdr, Recv, req.Name, req.Size, req.Count) + // The actual sends will have payload type payData. + // TODO: manage the count? + default: + error.Error = "request: can't handle channel direction" + expLog(error.Error, req.Dir) + client.encode(hdr, payError, error) + } + case payData: + client.serveSend(*hdr) + case payClosed: + client.serveClosed(*hdr) + case payAck: + client.mu.Lock() + if client.ackNum != hdr.SeqNum-1 { + // Since the sequence number is incremented and the message is sent + // in a single instance of locking client.mu, the messages are guaranteed + // to be sent in order. 
Therefore receipt of acknowledgement N means + // all messages <=N have been seen by the recipient. We check anyway. + expLog("sequence out of order:", client.ackNum, hdr.SeqNum) + } + if client.ackNum < hdr.SeqNum { // If there has been an error, don't back up the count. + client.ackNum = hdr.SeqNum + } + client.mu.Unlock() + case payAckSend: + if nch := client.getChan(hdr, Send); nch != nil { + nch.acked() + } + default: + log.Fatal("netchan export: unknown payload type", hdr.PayloadType) + } + } + client.exp.delClient(client) +} + +// Send all the data on a single channel to a client asking for a Recv. +// The header is passed by value to avoid issues of overwriting. +func (client *expClient) serveRecv(nch *netChan, hdr header, count int64) { + for { + val, ok := nch.recv() + if !ok { + if err := client.encode(&hdr, payClosed, nil); err != nil { + expLog("error encoding server closed message:", err) + } + break + } + // We hold the lock during transmission to guarantee messages are + // sent in sequence number order. Also, we increment first so the + // value of client.SeqNum is the value of the highest used sequence + // number, not one beyond. + client.mu.Lock() + client.seqNum++ + hdr.SeqNum = client.seqNum + client.seqLock.Lock() // guarantee ordering of messages + client.mu.Unlock() + err := client.encode(&hdr, payData, val.Interface()) + client.seqLock.Unlock() + if err != nil { + expLog("error encoding client response:", err) + client.sendError(&hdr, err.String()) + break + } + // Negative count means run forever. + if count >= 0 { + if count--; count <= 0 { + break + } + } + } +} + +// Receive and deliver locally one item from a client asking for a Send +// The header is passed by value to avoid issues of overwriting. +func (client *expClient) serveSend(hdr header) { + nch := client.getChan(&hdr, Recv) + if nch == nil { + return + } + // Create a new value for each received item. + val := reflect.Zero(nch.ch.Type().Elem()) + if err := client.decode(val); err != nil { + expLog("value decode:", err, "; type ", nch.ch.Type()) + return + } + nch.send(val) +} + +// Report that client has closed the channel that is sending to us. +// The header is passed by value to avoid issues of overwriting. +func (client *expClient) serveClosed(hdr header) { + nch := client.getChan(&hdr, Recv) + if nch == nil { + return + } + nch.close() +} + +func (client *expClient) unackedCount() int64 { + client.mu.Lock() + n := client.seqNum - client.ackNum + client.mu.Unlock() + return n +} + +func (client *expClient) seq() int64 { + client.mu.Lock() + n := client.seqNum + client.mu.Unlock() + return n +} + +func (client *expClient) ack() int64 { + client.mu.Lock() + n := client.seqNum + client.mu.Unlock() + return n +} + +// Serve waits for incoming connections on the listener +// and serves the Exporter's channels on each. +// It blocks until the listener is closed. +func (exp *Exporter) Serve(listener net.Listener) { + for { + conn, err := listener.Accept() + if err != nil { + expLog("listen:", err) + break + } + go exp.ServeConn(conn) + } +} + +// ServeConn exports the Exporter's channels on conn. +// It blocks until the connection is terminated. +func (exp *Exporter) ServeConn(conn io.ReadWriter) { + exp.addClient(conn).run() +} + +// NewExporter creates a new Exporter that exports a set of channels. 
+func NewExporter() *Exporter { + e := &Exporter{ + clientSet: &clientSet{ + names: make(map[string]*chanDir), + clients: make(map[unackedCounter]bool), + }, + } + return e +} + +// ListenAndServe exports the exporter's channels through the +// given network and local address defined as in net.Listen. +func (exp *Exporter) ListenAndServe(network, localaddr string) os.Error { + listener, err := net.Listen(network, localaddr) + if err != nil { + return err + } + go exp.Serve(listener) + return nil +} + +// addClient creates a new expClient and records its existence +func (exp *Exporter) addClient(conn io.ReadWriter) *expClient { + client := newClient(exp, conn) + exp.mu.Lock() + exp.clients[client] = true + exp.mu.Unlock() + return client +} + +// delClient forgets the client existed +func (exp *Exporter) delClient(client *expClient) { + exp.mu.Lock() + exp.clients[client] = false, false + exp.mu.Unlock() +} + +// Drain waits until all messages sent from this exporter/importer, including +// those not yet sent to any client and possibly including those sent while +// Drain was executing, have been received by the importer. In short, it +// waits until all the exporter's messages have been received by a client. +// If the timeout (measured in nanoseconds) is positive and Drain takes +// longer than that to complete, an error is returned. +func (exp *Exporter) Drain(timeout int64) os.Error { + // This wrapper function is here so the method's comment will appear in godoc. + return exp.clientSet.drain(timeout) +} + +// Sync waits until all clients of the exporter have received the messages +// that were sent at the time Sync was invoked. Unlike Drain, it does not +// wait for messages sent while it is running or messages that have not been +// dispatched to any client. If the timeout (measured in nanoseconds) is +// positive and Sync takes longer than that to complete, an error is +// returned. +func (exp *Exporter) Sync(timeout int64) os.Error { + // This wrapper function is here so the method's comment will appear in godoc. + return exp.clientSet.sync(timeout) +} + +func checkChan(chT interface{}, dir Dir) (reflect.Value, os.Error) { + chanType := reflect.Typeof(chT) + if chanType.Kind() != reflect.Chan { + return reflect.Value{}, os.ErrorString("not a channel") + } + if dir != Send && dir != Recv { + return reflect.Value{}, os.ErrorString("unknown channel direction") + } + switch chanType.ChanDir() { + case reflect.BothDir: + case reflect.SendDir: + if dir != Recv { + return reflect.Value{}, os.ErrorString("to import/export with Send, must provide <-chan") + } + case reflect.RecvDir: + if dir != Send { + return reflect.Value{}, os.ErrorString("to import/export with Recv, must provide chan<-") + } + } + return reflect.NewValue(chT), nil +} + +// Export exports a channel of a given type and specified direction. The +// channel to be exported is provided in the call and may be of arbitrary +// channel type. +// Despite the literal signature, the effective signature is +// Export(name string, chT chan T, dir Dir) +func (exp *Exporter) Export(name string, chT interface{}, dir Dir) os.Error { + ch, err := checkChan(chT, dir) + if err != nil { + return err + } + exp.mu.Lock() + defer exp.mu.Unlock() + _, present := exp.names[name] + if present { + return os.ErrorString("channel name already being exported:" + name) + } + exp.names[name] = &chanDir{ch, dir} + return nil +} + +// Hangup disassociates the named channel from the Exporter and closes +// the channel. 
Messages in flight for the channel may be dropped. +func (exp *Exporter) Hangup(name string) os.Error { + exp.mu.Lock() + chDir, ok := exp.names[name] + if ok { + exp.names[name] = nil, false + } + // TODO drop all instances of channel from client sets + exp.mu.Unlock() + if !ok { + return os.ErrorString("netchan export: hangup: no such channel: " + name) + } + chDir.ch.Close() + return nil +} diff --git a/src/cmd/gofix/testdata/reflect.print.go.in b/src/cmd/gofix/testdata/reflect.print.go.in new file mode 100644 index 000000000..194bcdb3b --- /dev/null +++ b/src/cmd/gofix/testdata/reflect.print.go.in @@ -0,0 +1,945 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package fmt + +import ( + "bytes" + "io" + "os" + "reflect" + "utf8" +) + +// Some constants in the form of bytes, to avoid string overhead. +// Needlessly fastidious, I suppose. +var ( + commaSpaceBytes = []byte(", ") + nilAngleBytes = []byte("<nil>") + nilParenBytes = []byte("(nil)") + nilBytes = []byte("nil") + mapBytes = []byte("map[") + missingBytes = []byte("(MISSING)") + extraBytes = []byte("%!(EXTRA ") + irparenBytes = []byte("i)") + bytesBytes = []byte("[]byte{") + widthBytes = []byte("%!(BADWIDTH)") + precBytes = []byte("%!(BADPREC)") + noVerbBytes = []byte("%!(NOVERB)") +) + +// State represents the printer state passed to custom formatters. +// It provides access to the io.Writer interface plus information about +// the flags and options for the operand's format specifier. +type State interface { + // Write is the function to call to emit formatted output to be printed. + Write(b []byte) (ret int, err os.Error) + // Width returns the value of the width option and whether it has been set. + Width() (wid int, ok bool) + // Precision returns the value of the precision option and whether it has been set. + Precision() (prec int, ok bool) + + // Flag returns whether the flag c, a character, has been set. + Flag(int) bool +} + +// Formatter is the interface implemented by values with a custom formatter. +// The implementation of Format may call Sprintf or Fprintf(f) etc. +// to generate its output. +type Formatter interface { + Format(f State, c int) +} + +// Stringer is implemented by any value that has a String method(), +// which defines the ``native'' format for that value. +// The String method is used to print values passed as an operand +// to a %s or %v format or to an unformatted printer such as Print. +type Stringer interface { + String() string +} + +// GoStringer is implemented by any value that has a GoString() method, +// which defines the Go syntax for that value. +// The GoString method is used to print values passed as an operand +// to a %#v format. +type GoStringer interface { + GoString() string +} + +type pp struct { + n int + buf bytes.Buffer + runeBuf [utf8.UTFMax]byte + fmt fmt +} + +// A cache holds a set of reusable objects. +// The buffered channel holds the currently available objects. +// If more are needed, the cache creates them by calling new. 
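Before the cache type below, a small usage sketch of this free-list pattern; bufFree is a hypothetical value invented for the example, while the real code below uses the same type as ppFree, a cache of pp structs.

// Sketch only, in the same package: a free list of bytes.Buffers built on
// the cache type defined below.
var bufFree = newCache(func() interface{} { return new(bytes.Buffer) })

func withBuffer(s string) string {
	b := bufFree.get().(*bytes.Buffer)
	b.Reset() // a recycled buffer may still hold old data
	b.WriteString(s)
	out := b.String()
	bufFree.put(b) // return it for reuse; put discards it if the cache is full
	return out
}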
+type cache struct { + saved chan interface{} + new func() interface{} +} + +func (c *cache) put(x interface{}) { + select { + case c.saved <- x: + // saved in cache + default: + // discard + } +} + +func (c *cache) get() interface{} { + select { + case x := <-c.saved: + return x // reused from cache + default: + return c.new() + } + panic("not reached") +} + +func newCache(f func() interface{}) *cache { + return &cache{make(chan interface{}, 100), f} +} + +var ppFree = newCache(func() interface{} { return new(pp) }) + +// Allocate a new pp struct or grab a cached one. +func newPrinter() *pp { + p := ppFree.get().(*pp) + p.fmt.init(&p.buf) + return p +} + +// Save used pp structs in ppFree; avoids an allocation per invocation. +func (p *pp) free() { + // Don't hold on to pp structs with large buffers. + if cap(p.buf.Bytes()) > 1024 { + return + } + p.buf.Reset() + ppFree.put(p) +} + +func (p *pp) Width() (wid int, ok bool) { return p.fmt.wid, p.fmt.widPresent } + +func (p *pp) Precision() (prec int, ok bool) { return p.fmt.prec, p.fmt.precPresent } + +func (p *pp) Flag(b int) bool { + switch b { + case '-': + return p.fmt.minus + case '+': + return p.fmt.plus + case '#': + return p.fmt.sharp + case ' ': + return p.fmt.space + case '0': + return p.fmt.zero + } + return false +} + +func (p *pp) add(c int) { + p.buf.WriteRune(c) +} + +// Implement Write so we can call Fprintf on a pp (through State), for +// recursive use in custom verbs. +func (p *pp) Write(b []byte) (ret int, err os.Error) { + return p.buf.Write(b) +} + +// These routines end in 'f' and take a format string. + +// Fprintf formats according to a format specifier and writes to w. +// It returns the number of bytes written and any write error encountered. +func Fprintf(w io.Writer, format string, a ...interface{}) (n int, error os.Error) { + p := newPrinter() + p.doPrintf(format, a) + n64, error := p.buf.WriteTo(w) + p.free() + return int(n64), error +} + +// Printf formats according to a format specifier and writes to standard output. +// It returns the number of bytes written and any write error encountered. +func Printf(format string, a ...interface{}) (n int, errno os.Error) { + n, errno = Fprintf(os.Stdout, format, a...) + return n, errno +} + +// Sprintf formats according to a format specifier and returns the resulting string. +func Sprintf(format string, a ...interface{}) string { + p := newPrinter() + p.doPrintf(format, a) + s := p.buf.String() + p.free() + return s +} + +// Errorf formats according to a format specifier and returns the string +// converted to an os.ErrorString, which satisfies the os.Error interface. +func Errorf(format string, a ...interface{}) os.Error { + return os.ErrorString(Sprintf(format, a...)) +} + +// These routines do not take a format string + +// Fprint formats using the default formats for its operands and writes to w. +// Spaces are added between operands when neither is a string. +// It returns the number of bytes written and any write error encountered. +func Fprint(w io.Writer, a ...interface{}) (n int, error os.Error) { + p := newPrinter() + p.doPrint(a, false, false) + n64, error := p.buf.WriteTo(w) + p.free() + return int(n64), error +} + +// Print formats using the default formats for its operands and writes to standard output. +// Spaces are added between operands when neither is a string. +// It returns the number of bytes written and any write error encountered. +func Print(a ...interface{}) (n int, errno os.Error) { + n, errno = Fprint(os.Stdout, a...) 
+ return n, errno +} + +// Sprint formats using the default formats for its operands and returns the resulting string. +// Spaces are added between operands when neither is a string. +func Sprint(a ...interface{}) string { + p := newPrinter() + p.doPrint(a, false, false) + s := p.buf.String() + p.free() + return s +} + +// These routines end in 'ln', do not take a format string, +// always add spaces between operands, and add a newline +// after the last operand. + +// Fprintln formats using the default formats for its operands and writes to w. +// Spaces are always added between operands and a newline is appended. +// It returns the number of bytes written and any write error encountered. +func Fprintln(w io.Writer, a ...interface{}) (n int, error os.Error) { + p := newPrinter() + p.doPrint(a, true, true) + n64, error := p.buf.WriteTo(w) + p.free() + return int(n64), error +} + +// Println formats using the default formats for its operands and writes to standard output. +// Spaces are always added between operands and a newline is appended. +// It returns the number of bytes written and any write error encountered. +func Println(a ...interface{}) (n int, errno os.Error) { + n, errno = Fprintln(os.Stdout, a...) + return n, errno +} + +// Sprintln formats using the default formats for its operands and returns the resulting string. +// Spaces are always added between operands and a newline is appended. +func Sprintln(a ...interface{}) string { + p := newPrinter() + p.doPrint(a, true, true) + s := p.buf.String() + p.free() + return s +} + + +// Get the i'th arg of the struct value. +// If the arg itself is an interface, return a value for +// the thing inside the interface, not the interface itself. +func getField(v *reflect.StructValue, i int) reflect.Value { + val := v.Field(i) + if i, ok := val.(*reflect.InterfaceValue); ok { + if inter := i.Interface(); inter != nil { + return reflect.NewValue(inter) + } + } + return val +} + +// Convert ASCII to integer. n is 0 (and got is false) if no number present. +func parsenum(s string, start, end int) (num int, isnum bool, newi int) { + if start >= end { + return 0, false, end + } + for newi = start; newi < end && '0' <= s[newi] && s[newi] <= '9'; newi++ { + num = num*10 + int(s[newi]-'0') + isnum = true + } + return +} + +func (p *pp) unknownType(v interface{}) { + if v == nil { + p.buf.Write(nilAngleBytes) + return + } + p.buf.WriteByte('?') + p.buf.WriteString(reflect.Typeof(v).String()) + p.buf.WriteByte('?') +} + +func (p *pp) badVerb(verb int, val interface{}) { + p.add('%') + p.add('!') + p.add(verb) + p.add('(') + if val == nil { + p.buf.Write(nilAngleBytes) + } else { + p.buf.WriteString(reflect.Typeof(val).String()) + p.add('=') + p.printField(val, 'v', false, false, 0) + } + p.add(')') +} + +func (p *pp) fmtBool(v bool, verb int, value interface{}) { + switch verb { + case 't', 'v': + p.fmt.fmt_boolean(v) + default: + p.badVerb(verb, value) + } +} + +// fmtC formats a rune for the 'c' format. +func (p *pp) fmtC(c int64) { + rune := int(c) // Check for overflow. 
+ if int64(rune) != c { + rune = utf8.RuneError + } + w := utf8.EncodeRune(p.runeBuf[0:utf8.UTFMax], rune) + p.fmt.pad(p.runeBuf[0:w]) +} + +func (p *pp) fmtInt64(v int64, verb int, value interface{}) { + switch verb { + case 'b': + p.fmt.integer(v, 2, signed, ldigits) + case 'c': + p.fmtC(v) + case 'd', 'v': + p.fmt.integer(v, 10, signed, ldigits) + case 'o': + p.fmt.integer(v, 8, signed, ldigits) + case 'x': + p.fmt.integer(v, 16, signed, ldigits) + case 'U': + p.fmtUnicode(v) + case 'X': + p.fmt.integer(v, 16, signed, udigits) + default: + p.badVerb(verb, value) + } +} + +// fmt0x64 formats a uint64 in hexadecimal and prefixes it with 0x or +// not, as requested, by temporarily setting the sharp flag. +func (p *pp) fmt0x64(v uint64, leading0x bool) { + sharp := p.fmt.sharp + p.fmt.sharp = leading0x + p.fmt.integer(int64(v), 16, unsigned, ldigits) + p.fmt.sharp = sharp +} + +// fmtUnicode formats a uint64 in U+1234 form by +// temporarily turning on the unicode flag and tweaking the precision. +func (p *pp) fmtUnicode(v int64) { + precPresent := p.fmt.precPresent + prec := p.fmt.prec + if !precPresent { + // If prec is already set, leave it alone; otherwise 4 is minimum. + p.fmt.prec = 4 + p.fmt.precPresent = true + } + p.fmt.unicode = true // turn on U+ + p.fmt.integer(int64(v), 16, unsigned, udigits) + p.fmt.unicode = false + p.fmt.prec = prec + p.fmt.precPresent = precPresent +} + +func (p *pp) fmtUint64(v uint64, verb int, goSyntax bool, value interface{}) { + switch verb { + case 'b': + p.fmt.integer(int64(v), 2, unsigned, ldigits) + case 'c': + p.fmtC(int64(v)) + case 'd': + p.fmt.integer(int64(v), 10, unsigned, ldigits) + case 'v': + if goSyntax { + p.fmt0x64(v, true) + } else { + p.fmt.integer(int64(v), 10, unsigned, ldigits) + } + case 'o': + p.fmt.integer(int64(v), 8, unsigned, ldigits) + case 'x': + p.fmt.integer(int64(v), 16, unsigned, ldigits) + case 'X': + p.fmt.integer(int64(v), 16, unsigned, udigits) + default: + p.badVerb(verb, value) + } +} + +func (p *pp) fmtFloat32(v float32, verb int, value interface{}) { + switch verb { + case 'b': + p.fmt.fmt_fb32(v) + case 'e': + p.fmt.fmt_e32(v) + case 'E': + p.fmt.fmt_E32(v) + case 'f': + p.fmt.fmt_f32(v) + case 'g', 'v': + p.fmt.fmt_g32(v) + case 'G': + p.fmt.fmt_G32(v) + default: + p.badVerb(verb, value) + } +} + +func (p *pp) fmtFloat64(v float64, verb int, value interface{}) { + switch verb { + case 'b': + p.fmt.fmt_fb64(v) + case 'e': + p.fmt.fmt_e64(v) + case 'E': + p.fmt.fmt_E64(v) + case 'f': + p.fmt.fmt_f64(v) + case 'g', 'v': + p.fmt.fmt_g64(v) + case 'G': + p.fmt.fmt_G64(v) + default: + p.badVerb(verb, value) + } +} + +func (p *pp) fmtComplex64(v complex64, verb int, value interface{}) { + switch verb { + case 'e', 'E', 'f', 'F', 'g', 'G': + p.fmt.fmt_c64(v, verb) + case 'v': + p.fmt.fmt_c64(v, 'g') + default: + p.badVerb(verb, value) + } +} + +func (p *pp) fmtComplex128(v complex128, verb int, value interface{}) { + switch verb { + case 'e', 'E', 'f', 'F', 'g', 'G': + p.fmt.fmt_c128(v, verb) + case 'v': + p.fmt.fmt_c128(v, 'g') + default: + p.badVerb(verb, value) + } +} + +func (p *pp) fmtString(v string, verb int, goSyntax bool, value interface{}) { + switch verb { + case 'v': + if goSyntax { + p.fmt.fmt_q(v) + } else { + p.fmt.fmt_s(v) + } + case 's': + p.fmt.fmt_s(v) + case 'x': + p.fmt.fmt_sx(v) + case 'X': + p.fmt.fmt_sX(v) + case 'q': + p.fmt.fmt_q(v) + default: + p.badVerb(verb, value) + } +} + +func (p *pp) fmtBytes(v []byte, verb int, goSyntax bool, depth int, value interface{}) { + if verb == 'v' || verb == 
'd' { + if goSyntax { + p.buf.Write(bytesBytes) + } else { + p.buf.WriteByte('[') + } + for i, c := range v { + if i > 0 { + if goSyntax { + p.buf.Write(commaSpaceBytes) + } else { + p.buf.WriteByte(' ') + } + } + p.printField(c, 'v', p.fmt.plus, goSyntax, depth+1) + } + if goSyntax { + p.buf.WriteByte('}') + } else { + p.buf.WriteByte(']') + } + return + } + s := string(v) + switch verb { + case 's': + p.fmt.fmt_s(s) + case 'x': + p.fmt.fmt_sx(s) + case 'X': + p.fmt.fmt_sX(s) + case 'q': + p.fmt.fmt_q(s) + default: + p.badVerb(verb, value) + } +} + +func (p *pp) fmtPointer(field interface{}, value reflect.Value, verb int, goSyntax bool) { + var u uintptr + switch value.(type) { + case *reflect.ChanValue, *reflect.FuncValue, *reflect.MapValue, *reflect.PtrValue, *reflect.SliceValue, *reflect.UnsafePointerValue: + u = value.Pointer() + default: + p.badVerb(verb, field) + return + } + if goSyntax { + p.add('(') + p.buf.WriteString(reflect.Typeof(field).String()) + p.add(')') + p.add('(') + if u == 0 { + p.buf.Write(nilBytes) + } else { + p.fmt0x64(uint64(u), true) + } + p.add(')') + } else { + p.fmt0x64(uint64(u), !p.fmt.sharp) + } +} + +var ( + intBits = reflect.Typeof(0).Bits() + floatBits = reflect.Typeof(0.0).Bits() + complexBits = reflect.Typeof(1i).Bits() + uintptrBits = reflect.Typeof(uintptr(0)).Bits() +) + +func (p *pp) printField(field interface{}, verb int, plus, goSyntax bool, depth int) (wasString bool) { + if field == nil { + if verb == 'T' || verb == 'v' { + p.buf.Write(nilAngleBytes) + } else { + p.badVerb(verb, field) + } + return false + } + + // Special processing considerations. + // %T (the value's type) and %p (its address) are special; we always do them first. + switch verb { + case 'T': + p.printField(reflect.Typeof(field).String(), 's', false, false, 0) + return false + case 'p': + p.fmtPointer(field, reflect.NewValue(field), verb, goSyntax) + return false + } + // Is it a Formatter? + if formatter, ok := field.(Formatter); ok { + formatter.Format(p, verb) + return false // this value is not a string + + } + // Must not touch flags before Formatter looks at them. + if plus { + p.fmt.plus = false + } + // If we're doing Go syntax and the field knows how to supply it, take care of it now. + if goSyntax { + p.fmt.sharp = false + if stringer, ok := field.(GoStringer); ok { + // Print the result of GoString unadorned. + p.fmtString(stringer.GoString(), 's', false, field) + return false // this value is not a string + } + } else { + // Is it a Stringer? + if stringer, ok := field.(Stringer); ok { + p.printField(stringer.String(), verb, plus, false, depth) + return false // this value is not a string + } + } + + // Some types can be done without reflection. 
+ switch f := field.(type) { + case bool: + p.fmtBool(f, verb, field) + return false + case float32: + p.fmtFloat32(f, verb, field) + return false + case float64: + p.fmtFloat64(f, verb, field) + return false + case complex64: + p.fmtComplex64(complex64(f), verb, field) + return false + case complex128: + p.fmtComplex128(f, verb, field) + return false + case int: + p.fmtInt64(int64(f), verb, field) + return false + case int8: + p.fmtInt64(int64(f), verb, field) + return false + case int16: + p.fmtInt64(int64(f), verb, field) + return false + case int32: + p.fmtInt64(int64(f), verb, field) + return false + case int64: + p.fmtInt64(f, verb, field) + return false + case uint: + p.fmtUint64(uint64(f), verb, goSyntax, field) + return false + case uint8: + p.fmtUint64(uint64(f), verb, goSyntax, field) + return false + case uint16: + p.fmtUint64(uint64(f), verb, goSyntax, field) + return false + case uint32: + p.fmtUint64(uint64(f), verb, goSyntax, field) + return false + case uint64: + p.fmtUint64(f, verb, goSyntax, field) + return false + case uintptr: + p.fmtUint64(uint64(f), verb, goSyntax, field) + return false + case string: + p.fmtString(f, verb, goSyntax, field) + return verb == 's' || verb == 'v' + case []byte: + p.fmtBytes(f, verb, goSyntax, depth, field) + return verb == 's' + } + + // Need to use reflection + value := reflect.NewValue(field) + +BigSwitch: + switch f := value.(type) { + case *reflect.BoolValue: + p.fmtBool(f.Get(), verb, field) + case *reflect.IntValue: + p.fmtInt64(f.Get(), verb, field) + case *reflect.UintValue: + p.fmtUint64(uint64(f.Get()), verb, goSyntax, field) + case *reflect.FloatValue: + if f.Type().Size() == 4 { + p.fmtFloat32(float32(f.Get()), verb, field) + } else { + p.fmtFloat64(float64(f.Get()), verb, field) + } + case *reflect.ComplexValue: + if f.Type().Size() == 8 { + p.fmtComplex64(complex64(f.Get()), verb, field) + } else { + p.fmtComplex128(complex128(f.Get()), verb, field) + } + case *reflect.StringValue: + p.fmtString(f.Get(), verb, goSyntax, field) + case *reflect.MapValue: + if goSyntax { + p.buf.WriteString(f.Type().String()) + p.buf.WriteByte('{') + } else { + p.buf.Write(mapBytes) + } + keys := f.Keys() + for i, key := range keys { + if i > 0 { + if goSyntax { + p.buf.Write(commaSpaceBytes) + } else { + p.buf.WriteByte(' ') + } + } + p.printField(key.Interface(), verb, plus, goSyntax, depth+1) + p.buf.WriteByte(':') + p.printField(f.Elem(key).Interface(), verb, plus, goSyntax, depth+1) + } + if goSyntax { + p.buf.WriteByte('}') + } else { + p.buf.WriteByte(']') + } + case *reflect.StructValue: + if goSyntax { + p.buf.WriteString(reflect.Typeof(field).String()) + } + p.add('{') + v := f + t := v.Type().(*reflect.StructType) + for i := 0; i < v.NumField(); i++ { + if i > 0 { + if goSyntax { + p.buf.Write(commaSpaceBytes) + } else { + p.buf.WriteByte(' ') + } + } + if plus || goSyntax { + if f := t.Field(i); f.Name != "" { + p.buf.WriteString(f.Name) + p.buf.WriteByte(':') + } + } + p.printField(getField(v, i).Interface(), verb, plus, goSyntax, depth+1) + } + p.buf.WriteByte('}') + case *reflect.InterfaceValue: + value := f.Elem() + if value == nil { + if goSyntax { + p.buf.WriteString(reflect.Typeof(field).String()) + p.buf.Write(nilParenBytes) + } else { + p.buf.Write(nilAngleBytes) + } + } else { + return p.printField(value.Interface(), verb, plus, goSyntax, depth+1) + } + case reflect.ArrayOrSliceValue: + // Byte slices are special. 
+ if f.Type().(reflect.ArrayOrSliceType).Elem().Kind() == reflect.Uint8 { + // We know it's a slice of bytes, but we also know it does not have static type + // []byte, or it would have been caught above. Therefore we cannot convert + // it directly in the (slightly) obvious way: f.Interface().([]byte); it doesn't have + // that type, and we can't write an expression of the right type and do a + // conversion because we don't have a static way to write the right type. + // So we build a slice by hand. This is a rare case but it would be nice + // if reflection could help a little more. + bytes := make([]byte, f.Len()) + for i := range bytes { + bytes[i] = byte(f.Elem(i).(*reflect.UintValue).Get()) + } + p.fmtBytes(bytes, verb, goSyntax, depth, field) + return verb == 's' + } + if goSyntax { + p.buf.WriteString(reflect.Typeof(field).String()) + p.buf.WriteByte('{') + } else { + p.buf.WriteByte('[') + } + for i := 0; i < f.Len(); i++ { + if i > 0 { + if goSyntax { + p.buf.Write(commaSpaceBytes) + } else { + p.buf.WriteByte(' ') + } + } + p.printField(f.Elem(i).Interface(), verb, plus, goSyntax, depth+1) + } + if goSyntax { + p.buf.WriteByte('}') + } else { + p.buf.WriteByte(']') + } + case *reflect.PtrValue: + v := f.Get() + // pointer to array or slice or struct? ok at top level + // but not embedded (avoid loops) + if v != 0 && depth == 0 { + switch a := f.Elem().(type) { + case reflect.ArrayOrSliceValue: + p.buf.WriteByte('&') + p.printField(a.Interface(), verb, plus, goSyntax, depth+1) + break BigSwitch + case *reflect.StructValue: + p.buf.WriteByte('&') + p.printField(a.Interface(), verb, plus, goSyntax, depth+1) + break BigSwitch + } + } + if goSyntax { + p.buf.WriteByte('(') + p.buf.WriteString(reflect.Typeof(field).String()) + p.buf.WriteByte(')') + p.buf.WriteByte('(') + if v == 0 { + p.buf.Write(nilBytes) + } else { + p.fmt0x64(uint64(v), true) + } + p.buf.WriteByte(')') + break + } + if v == 0 { + p.buf.Write(nilAngleBytes) + break + } + p.fmt0x64(uint64(v), true) + case *reflect.ChanValue, *reflect.FuncValue, *reflect.UnsafePointerValue: + p.fmtPointer(field, value, verb, goSyntax) + default: + p.unknownType(f) + } + return false +} + +// intFromArg gets the fieldnumth element of a. On return, isInt reports whether the argument has type int. +func intFromArg(a []interface{}, end, i, fieldnum int) (num int, isInt bool, newi, newfieldnum int) { + newi, newfieldnum = end, fieldnum + if i < end && fieldnum < len(a) { + num, isInt = a[fieldnum].(int) + newi, newfieldnum = i+1, fieldnum+1 + } + return +} + +func (p *pp) doPrintf(format string, a []interface{}) { + end := len(format) + fieldnum := 0 // we process one field per non-trivial format + for i := 0; i < end; { + lasti := i + for i < end && format[i] != '%' { + i++ + } + if i > lasti { + p.buf.WriteString(format[lasti:i]) + } + if i >= end { + // done processing format string + break + } + + // Process one verb + i++ + // flags and widths + p.fmt.clearflags() + F: + for ; i < end; i++ { + switch format[i] { + case '#': + p.fmt.sharp = true + case '0': + p.fmt.zero = true + case '+': + p.fmt.plus = true + case '-': + p.fmt.minus = true + case ' ': + p.fmt.space = true + default: + break F + } + } + // do we have width? + if i < end && format[i] == '*' { + p.fmt.wid, p.fmt.widPresent, i, fieldnum = intFromArg(a, end, i, fieldnum) + if !p.fmt.widPresent { + p.buf.Write(widthBytes) + } + } else { + p.fmt.wid, p.fmt.widPresent, i = parsenum(format, i, end) + } + // do we have precision? + if i < end && format[i] == '.' 
{ + if format[i+1] == '*' { + p.fmt.prec, p.fmt.precPresent, i, fieldnum = intFromArg(a, end, i+1, fieldnum) + if !p.fmt.precPresent { + p.buf.Write(precBytes) + } + } else { + p.fmt.prec, p.fmt.precPresent, i = parsenum(format, i+1, end) + } + } + if i >= end { + p.buf.Write(noVerbBytes) + continue + } + c, w := utf8.DecodeRuneInString(format[i:]) + i += w + // percent is special - absorbs no operand + if c == '%' { + p.buf.WriteByte('%') // We ignore width and prec. + continue + } + if fieldnum >= len(a) { // out of operands + p.buf.WriteByte('%') + p.add(c) + p.buf.Write(missingBytes) + continue + } + field := a[fieldnum] + fieldnum++ + + goSyntax := c == 'v' && p.fmt.sharp + plus := c == 'v' && p.fmt.plus + p.printField(field, c, plus, goSyntax, 0) + } + + if fieldnum < len(a) { + p.buf.Write(extraBytes) + for ; fieldnum < len(a); fieldnum++ { + field := a[fieldnum] + if field != nil { + p.buf.WriteString(reflect.Typeof(field).String()) + p.buf.WriteByte('=') + } + p.printField(field, 'v', false, false, 0) + if fieldnum+1 < len(a) { + p.buf.Write(commaSpaceBytes) + } + } + p.buf.WriteByte(')') + } +} + +func (p *pp) doPrint(a []interface{}, addspace, addnewline bool) { + prevString := false + for fieldnum := 0; fieldnum < len(a); fieldnum++ { + p.fmt.clearflags() + // always add spaces if we're doing println + field := a[fieldnum] + if fieldnum > 0 { + isString := field != nil && reflect.Typeof(field).Kind() == reflect.String + if addspace || !isString && !prevString { + p.buf.WriteByte(' ') + } + } + prevString = p.printField(field, 'v', false, false, 0) + } + if addnewline { + p.buf.WriteByte('\n') + } +} diff --git a/src/cmd/gofix/testdata/reflect.print.go.out b/src/cmd/gofix/testdata/reflect.print.go.out new file mode 100644 index 000000000..e3dc775cf --- /dev/null +++ b/src/cmd/gofix/testdata/reflect.print.go.out @@ -0,0 +1,945 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package fmt + +import ( + "bytes" + "io" + "os" + "reflect" + "utf8" +) + +// Some constants in the form of bytes, to avoid string overhead. +// Needlessly fastidious, I suppose. +var ( + commaSpaceBytes = []byte(", ") + nilAngleBytes = []byte("<nil>") + nilParenBytes = []byte("(nil)") + nilBytes = []byte("nil") + mapBytes = []byte("map[") + missingBytes = []byte("(MISSING)") + extraBytes = []byte("%!(EXTRA ") + irparenBytes = []byte("i)") + bytesBytes = []byte("[]byte{") + widthBytes = []byte("%!(BADWIDTH)") + precBytes = []byte("%!(BADPREC)") + noVerbBytes = []byte("%!(NOVERB)") +) + +// State represents the printer state passed to custom formatters. +// It provides access to the io.Writer interface plus information about +// the flags and options for the operand's format specifier. +type State interface { + // Write is the function to call to emit formatted output to be printed. + Write(b []byte) (ret int, err os.Error) + // Width returns the value of the width option and whether it has been set. + Width() (wid int, ok bool) + // Precision returns the value of the precision option and whether it has been set. + Precision() (prec int, ok bool) + + // Flag returns whether the flag c, a character, has been set. + Flag(int) bool +} + +// Formatter is the interface implemented by values with a custom formatter. +// The implementation of Format may call Sprintf or Fprintf(f) etc. +// to generate its output. 
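As a concrete illustration of the State and Formatter hooks declared here, a minimal sketch written against this snapshot's signatures (the verb arrives as an int and errors are os.Error); the Celsius type and its output are invented for the example and are not part of the patch:

	// (assumes: import "fmt")

	// Celsius is a hypothetical type with a custom formatter.
	type Celsius float64

	// Format matches the snapshot's Formatter signature (verb passed as int).
	func (c Celsius) Format(f fmt.State, verb int) {
		prec, ok := f.Precision()
		if !ok {
			prec = 1 // default when the caller gave no precision
		}
		switch verb {
		case 'v', 'f':
			// Recurse through the State's Write method.
			fmt.Fprintf(f, "%.*f C", prec, float64(c))
		default:
			fmt.Fprintf(f, "%%!%c(Celsius=%g)", verb, float64(c))
		}
	}

	// fmt.Printf("%.2v\n", Celsius(21.5)) would then print "21.50 C".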
+type Formatter interface { + Format(f State, c int) +} + +// Stringer is implemented by any value that has a String method(), +// which defines the ``native'' format for that value. +// The String method is used to print values passed as an operand +// to a %s or %v format or to an unformatted printer such as Print. +type Stringer interface { + String() string +} + +// GoStringer is implemented by any value that has a GoString() method, +// which defines the Go syntax for that value. +// The GoString method is used to print values passed as an operand +// to a %#v format. +type GoStringer interface { + GoString() string +} + +type pp struct { + n int + buf bytes.Buffer + runeBuf [utf8.UTFMax]byte + fmt fmt +} + +// A cache holds a set of reusable objects. +// The buffered channel holds the currently available objects. +// If more are needed, the cache creates them by calling new. +type cache struct { + saved chan interface{} + new func() interface{} +} + +func (c *cache) put(x interface{}) { + select { + case c.saved <- x: + // saved in cache + default: + // discard + } +} + +func (c *cache) get() interface{} { + select { + case x := <-c.saved: + return x // reused from cache + default: + return c.new() + } + panic("not reached") +} + +func newCache(f func() interface{}) *cache { + return &cache{make(chan interface{}, 100), f} +} + +var ppFree = newCache(func() interface{} { return new(pp) }) + +// Allocate a new pp struct or grab a cached one. +func newPrinter() *pp { + p := ppFree.get().(*pp) + p.fmt.init(&p.buf) + return p +} + +// Save used pp structs in ppFree; avoids an allocation per invocation. +func (p *pp) free() { + // Don't hold on to pp structs with large buffers. + if cap(p.buf.Bytes()) > 1024 { + return + } + p.buf.Reset() + ppFree.put(p) +} + +func (p *pp) Width() (wid int, ok bool) { return p.fmt.wid, p.fmt.widPresent } + +func (p *pp) Precision() (prec int, ok bool) { return p.fmt.prec, p.fmt.precPresent } + +func (p *pp) Flag(b int) bool { + switch b { + case '-': + return p.fmt.minus + case '+': + return p.fmt.plus + case '#': + return p.fmt.sharp + case ' ': + return p.fmt.space + case '0': + return p.fmt.zero + } + return false +} + +func (p *pp) add(c int) { + p.buf.WriteRune(c) +} + +// Implement Write so we can call Fprintf on a pp (through State), for +// recursive use in custom verbs. +func (p *pp) Write(b []byte) (ret int, err os.Error) { + return p.buf.Write(b) +} + +// These routines end in 'f' and take a format string. + +// Fprintf formats according to a format specifier and writes to w. +// It returns the number of bytes written and any write error encountered. +func Fprintf(w io.Writer, format string, a ...interface{}) (n int, error os.Error) { + p := newPrinter() + p.doPrintf(format, a) + n64, error := p.buf.WriteTo(w) + p.free() + return int(n64), error +} + +// Printf formats according to a format specifier and writes to standard output. +// It returns the number of bytes written and any write error encountered. +func Printf(format string, a ...interface{}) (n int, errno os.Error) { + n, errno = Fprintf(os.Stdout, format, a...) + return n, errno +} + +// Sprintf formats according to a format specifier and returns the resulting string. +func Sprintf(format string, a ...interface{}) string { + p := newPrinter() + p.doPrintf(format, a) + s := p.buf.String() + p.free() + return s +} + +// Errorf formats according to a format specifier and returns the string +// converted to an os.ErrorString, which satisfies the os.Error interface. 
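The cache type defined above is a small free list built on a buffered channel: get prefers a saved object and falls back to the new function, while put silently discards its argument once the channel is full. A minimal usage sketch with an invented payload, not part of the patch:

	// (assumes: import "bytes")
	bufCache := newCache(func() interface{} { return new(bytes.Buffer) })

	b := bufCache.get().(*bytes.Buffer) // reused if one is cached, freshly allocated otherwise
	b.WriteString("scratch work")
	b.Reset()
	bufCache.put(b) // dropped without blocking if 100 buffers are already saved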
+func Errorf(format string, a ...interface{}) os.Error { + return os.ErrorString(Sprintf(format, a...)) +} + +// These routines do not take a format string + +// Fprint formats using the default formats for its operands and writes to w. +// Spaces are added between operands when neither is a string. +// It returns the number of bytes written and any write error encountered. +func Fprint(w io.Writer, a ...interface{}) (n int, error os.Error) { + p := newPrinter() + p.doPrint(a, false, false) + n64, error := p.buf.WriteTo(w) + p.free() + return int(n64), error +} + +// Print formats using the default formats for its operands and writes to standard output. +// Spaces are added between operands when neither is a string. +// It returns the number of bytes written and any write error encountered. +func Print(a ...interface{}) (n int, errno os.Error) { + n, errno = Fprint(os.Stdout, a...) + return n, errno +} + +// Sprint formats using the default formats for its operands and returns the resulting string. +// Spaces are added between operands when neither is a string. +func Sprint(a ...interface{}) string { + p := newPrinter() + p.doPrint(a, false, false) + s := p.buf.String() + p.free() + return s +} + +// These routines end in 'ln', do not take a format string, +// always add spaces between operands, and add a newline +// after the last operand. + +// Fprintln formats using the default formats for its operands and writes to w. +// Spaces are always added between operands and a newline is appended. +// It returns the number of bytes written and any write error encountered. +func Fprintln(w io.Writer, a ...interface{}) (n int, error os.Error) { + p := newPrinter() + p.doPrint(a, true, true) + n64, error := p.buf.WriteTo(w) + p.free() + return int(n64), error +} + +// Println formats using the default formats for its operands and writes to standard output. +// Spaces are always added between operands and a newline is appended. +// It returns the number of bytes written and any write error encountered. +func Println(a ...interface{}) (n int, errno os.Error) { + n, errno = Fprintln(os.Stdout, a...) + return n, errno +} + +// Sprintln formats using the default formats for its operands and returns the resulting string. +// Spaces are always added between operands and a newline is appended. +func Sprintln(a ...interface{}) string { + p := newPrinter() + p.doPrint(a, true, true) + s := p.buf.String() + p.free() + return s +} + + +// Get the i'th arg of the struct value. +// If the arg itself is an interface, return a value for +// the thing inside the interface, not the interface itself. +func getField(v reflect.Value, i int) reflect.Value { + val := v.Field(i) + if i := val; i.Kind() == reflect.Interface { + if inter := i.Interface(); inter != nil { + return reflect.NewValue(inter) + } + } + return val +} + +// Convert ASCII to integer. n is 0 (and got is false) if no number present. 
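For concreteness, the number scanner defined next behaves like this on a simple format (start points just past the '%'):

	num, isnum, newi := parsenum("%123d", 1, 5)
	// num == 123, isnum == true, newi == 4 (the index of the verb 'd')

	// With no leading digits, e.g. parsenum("%d", 1, 2), it returns 0, false, 1.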
+func parsenum(s string, start, end int) (num int, isnum bool, newi int) { + if start >= end { + return 0, false, end + } + for newi = start; newi < end && '0' <= s[newi] && s[newi] <= '9'; newi++ { + num = num*10 + int(s[newi]-'0') + isnum = true + } + return +} + +func (p *pp) unknownType(v interface{}) { + if v == nil { + p.buf.Write(nilAngleBytes) + return + } + p.buf.WriteByte('?') + p.buf.WriteString(reflect.Typeof(v).String()) + p.buf.WriteByte('?') +} + +func (p *pp) badVerb(verb int, val interface{}) { + p.add('%') + p.add('!') + p.add(verb) + p.add('(') + if val == nil { + p.buf.Write(nilAngleBytes) + } else { + p.buf.WriteString(reflect.Typeof(val).String()) + p.add('=') + p.printField(val, 'v', false, false, 0) + } + p.add(')') +} + +func (p *pp) fmtBool(v bool, verb int, value interface{}) { + switch verb { + case 't', 'v': + p.fmt.fmt_boolean(v) + default: + p.badVerb(verb, value) + } +} + +// fmtC formats a rune for the 'c' format. +func (p *pp) fmtC(c int64) { + rune := int(c) // Check for overflow. + if int64(rune) != c { + rune = utf8.RuneError + } + w := utf8.EncodeRune(p.runeBuf[0:utf8.UTFMax], rune) + p.fmt.pad(p.runeBuf[0:w]) +} + +func (p *pp) fmtInt64(v int64, verb int, value interface{}) { + switch verb { + case 'b': + p.fmt.integer(v, 2, signed, ldigits) + case 'c': + p.fmtC(v) + case 'd', 'v': + p.fmt.integer(v, 10, signed, ldigits) + case 'o': + p.fmt.integer(v, 8, signed, ldigits) + case 'x': + p.fmt.integer(v, 16, signed, ldigits) + case 'U': + p.fmtUnicode(v) + case 'X': + p.fmt.integer(v, 16, signed, udigits) + default: + p.badVerb(verb, value) + } +} + +// fmt0x64 formats a uint64 in hexadecimal and prefixes it with 0x or +// not, as requested, by temporarily setting the sharp flag. +func (p *pp) fmt0x64(v uint64, leading0x bool) { + sharp := p.fmt.sharp + p.fmt.sharp = leading0x + p.fmt.integer(int64(v), 16, unsigned, ldigits) + p.fmt.sharp = sharp +} + +// fmtUnicode formats a uint64 in U+1234 form by +// temporarily turning on the unicode flag and tweaking the precision. +func (p *pp) fmtUnicode(v int64) { + precPresent := p.fmt.precPresent + prec := p.fmt.prec + if !precPresent { + // If prec is already set, leave it alone; otherwise 4 is minimum. 
+ p.fmt.prec = 4 + p.fmt.precPresent = true + } + p.fmt.unicode = true // turn on U+ + p.fmt.integer(int64(v), 16, unsigned, udigits) + p.fmt.unicode = false + p.fmt.prec = prec + p.fmt.precPresent = precPresent +} + +func (p *pp) fmtUint64(v uint64, verb int, goSyntax bool, value interface{}) { + switch verb { + case 'b': + p.fmt.integer(int64(v), 2, unsigned, ldigits) + case 'c': + p.fmtC(int64(v)) + case 'd': + p.fmt.integer(int64(v), 10, unsigned, ldigits) + case 'v': + if goSyntax { + p.fmt0x64(v, true) + } else { + p.fmt.integer(int64(v), 10, unsigned, ldigits) + } + case 'o': + p.fmt.integer(int64(v), 8, unsigned, ldigits) + case 'x': + p.fmt.integer(int64(v), 16, unsigned, ldigits) + case 'X': + p.fmt.integer(int64(v), 16, unsigned, udigits) + default: + p.badVerb(verb, value) + } +} + +func (p *pp) fmtFloat32(v float32, verb int, value interface{}) { + switch verb { + case 'b': + p.fmt.fmt_fb32(v) + case 'e': + p.fmt.fmt_e32(v) + case 'E': + p.fmt.fmt_E32(v) + case 'f': + p.fmt.fmt_f32(v) + case 'g', 'v': + p.fmt.fmt_g32(v) + case 'G': + p.fmt.fmt_G32(v) + default: + p.badVerb(verb, value) + } +} + +func (p *pp) fmtFloat64(v float64, verb int, value interface{}) { + switch verb { + case 'b': + p.fmt.fmt_fb64(v) + case 'e': + p.fmt.fmt_e64(v) + case 'E': + p.fmt.fmt_E64(v) + case 'f': + p.fmt.fmt_f64(v) + case 'g', 'v': + p.fmt.fmt_g64(v) + case 'G': + p.fmt.fmt_G64(v) + default: + p.badVerb(verb, value) + } +} + +func (p *pp) fmtComplex64(v complex64, verb int, value interface{}) { + switch verb { + case 'e', 'E', 'f', 'F', 'g', 'G': + p.fmt.fmt_c64(v, verb) + case 'v': + p.fmt.fmt_c64(v, 'g') + default: + p.badVerb(verb, value) + } +} + +func (p *pp) fmtComplex128(v complex128, verb int, value interface{}) { + switch verb { + case 'e', 'E', 'f', 'F', 'g', 'G': + p.fmt.fmt_c128(v, verb) + case 'v': + p.fmt.fmt_c128(v, 'g') + default: + p.badVerb(verb, value) + } +} + +func (p *pp) fmtString(v string, verb int, goSyntax bool, value interface{}) { + switch verb { + case 'v': + if goSyntax { + p.fmt.fmt_q(v) + } else { + p.fmt.fmt_s(v) + } + case 's': + p.fmt.fmt_s(v) + case 'x': + p.fmt.fmt_sx(v) + case 'X': + p.fmt.fmt_sX(v) + case 'q': + p.fmt.fmt_q(v) + default: + p.badVerb(verb, value) + } +} + +func (p *pp) fmtBytes(v []byte, verb int, goSyntax bool, depth int, value interface{}) { + if verb == 'v' || verb == 'd' { + if goSyntax { + p.buf.Write(bytesBytes) + } else { + p.buf.WriteByte('[') + } + for i, c := range v { + if i > 0 { + if goSyntax { + p.buf.Write(commaSpaceBytes) + } else { + p.buf.WriteByte(' ') + } + } + p.printField(c, 'v', p.fmt.plus, goSyntax, depth+1) + } + if goSyntax { + p.buf.WriteByte('}') + } else { + p.buf.WriteByte(']') + } + return + } + s := string(v) + switch verb { + case 's': + p.fmt.fmt_s(s) + case 'x': + p.fmt.fmt_sx(s) + case 'X': + p.fmt.fmt_sX(s) + case 'q': + p.fmt.fmt_q(s) + default: + p.badVerb(verb, value) + } +} + +func (p *pp) fmtPointer(field interface{}, value reflect.Value, verb int, goSyntax bool) { + var u uintptr + switch value.Kind() { + case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.Slice, reflect.UnsafePointer: + u = value.Pointer() + default: + p.badVerb(verb, field) + return + } + if goSyntax { + p.add('(') + p.buf.WriteString(reflect.Typeof(field).String()) + p.add(')') + p.add('(') + if u == 0 { + p.buf.Write(nilBytes) + } else { + p.fmt0x64(uint64(u), true) + } + p.add(')') + } else { + p.fmt0x64(uint64(u), !p.fmt.sharp) + } +} + +var ( + intBits = reflect.Typeof(0).Bits() + floatBits = 
reflect.Typeof(0.0).Bits() + complexBits = reflect.Typeof(1i).Bits() + uintptrBits = reflect.Typeof(uintptr(0)).Bits() +) + +func (p *pp) printField(field interface{}, verb int, plus, goSyntax bool, depth int) (wasString bool) { + if field == nil { + if verb == 'T' || verb == 'v' { + p.buf.Write(nilAngleBytes) + } else { + p.badVerb(verb, field) + } + return false + } + + // Special processing considerations. + // %T (the value's type) and %p (its address) are special; we always do them first. + switch verb { + case 'T': + p.printField(reflect.Typeof(field).String(), 's', false, false, 0) + return false + case 'p': + p.fmtPointer(field, reflect.NewValue(field), verb, goSyntax) + return false + } + // Is it a Formatter? + if formatter, ok := field.(Formatter); ok { + formatter.Format(p, verb) + return false // this value is not a string + + } + // Must not touch flags before Formatter looks at them. + if plus { + p.fmt.plus = false + } + // If we're doing Go syntax and the field knows how to supply it, take care of it now. + if goSyntax { + p.fmt.sharp = false + if stringer, ok := field.(GoStringer); ok { + // Print the result of GoString unadorned. + p.fmtString(stringer.GoString(), 's', false, field) + return false // this value is not a string + } + } else { + // Is it a Stringer? + if stringer, ok := field.(Stringer); ok { + p.printField(stringer.String(), verb, plus, false, depth) + return false // this value is not a string + } + } + + // Some types can be done without reflection. + switch f := field.(type) { + case bool: + p.fmtBool(f, verb, field) + return false + case float32: + p.fmtFloat32(f, verb, field) + return false + case float64: + p.fmtFloat64(f, verb, field) + return false + case complex64: + p.fmtComplex64(complex64(f), verb, field) + return false + case complex128: + p.fmtComplex128(f, verb, field) + return false + case int: + p.fmtInt64(int64(f), verb, field) + return false + case int8: + p.fmtInt64(int64(f), verb, field) + return false + case int16: + p.fmtInt64(int64(f), verb, field) + return false + case int32: + p.fmtInt64(int64(f), verb, field) + return false + case int64: + p.fmtInt64(f, verb, field) + return false + case uint: + p.fmtUint64(uint64(f), verb, goSyntax, field) + return false + case uint8: + p.fmtUint64(uint64(f), verb, goSyntax, field) + return false + case uint16: + p.fmtUint64(uint64(f), verb, goSyntax, field) + return false + case uint32: + p.fmtUint64(uint64(f), verb, goSyntax, field) + return false + case uint64: + p.fmtUint64(f, verb, goSyntax, field) + return false + case uintptr: + p.fmtUint64(uint64(f), verb, goSyntax, field) + return false + case string: + p.fmtString(f, verb, goSyntax, field) + return verb == 's' || verb == 'v' + case []byte: + p.fmtBytes(f, verb, goSyntax, depth, field) + return verb == 's' + } + + // Need to use reflection + value := reflect.NewValue(field) + +BigSwitch: + switch f := value; f.Kind() { + case reflect.Bool: + p.fmtBool(f.Bool(), verb, field) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + p.fmtInt64(f.Int(), verb, field) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + p.fmtUint64(uint64(f.Uint()), verb, goSyntax, field) + case reflect.Float32, reflect.Float64: + if f.Type().Size() == 4 { + p.fmtFloat32(float32(f.Float()), verb, field) + } else { + p.fmtFloat64(float64(f.Float()), verb, field) + } + case reflect.Complex64, reflect.Complex128: + if f.Type().Size() == 8 { + p.fmtComplex64(complex64(f.Complex()), verb, 
field) + } else { + p.fmtComplex128(complex128(f.Complex()), verb, field) + } + case reflect.String: + p.fmtString(f.String(), verb, goSyntax, field) + case reflect.Map: + if goSyntax { + p.buf.WriteString(f.Type().String()) + p.buf.WriteByte('{') + } else { + p.buf.Write(mapBytes) + } + keys := f.MapKeys() + for i, key := range keys { + if i > 0 { + if goSyntax { + p.buf.Write(commaSpaceBytes) + } else { + p.buf.WriteByte(' ') + } + } + p.printField(key.Interface(), verb, plus, goSyntax, depth+1) + p.buf.WriteByte(':') + p.printField(f.MapIndex(key).Interface(), verb, plus, goSyntax, depth+1) + } + if goSyntax { + p.buf.WriteByte('}') + } else { + p.buf.WriteByte(']') + } + case reflect.Struct: + if goSyntax { + p.buf.WriteString(reflect.Typeof(field).String()) + } + p.add('{') + v := f + t := v.Type() + for i := 0; i < v.NumField(); i++ { + if i > 0 { + if goSyntax { + p.buf.Write(commaSpaceBytes) + } else { + p.buf.WriteByte(' ') + } + } + if plus || goSyntax { + if f := t.Field(i); f.Name != "" { + p.buf.WriteString(f.Name) + p.buf.WriteByte(':') + } + } + p.printField(getField(v, i).Interface(), verb, plus, goSyntax, depth+1) + } + p.buf.WriteByte('}') + case reflect.Interface: + value := f.Elem() + if !value.IsValid() { + if goSyntax { + p.buf.WriteString(reflect.Typeof(field).String()) + p.buf.Write(nilParenBytes) + } else { + p.buf.Write(nilAngleBytes) + } + } else { + return p.printField(value.Interface(), verb, plus, goSyntax, depth+1) + } + case reflect.Array, reflect.Slice: + // Byte slices are special. + if f.Type().Elem().Kind() == reflect.Uint8 { + // We know it's a slice of bytes, but we also know it does not have static type + // []byte, or it would have been caught above. Therefore we cannot convert + // it directly in the (slightly) obvious way: f.Interface().([]byte); it doesn't have + // that type, and we can't write an expression of the right type and do a + // conversion because we don't have a static way to write the right type. + // So we build a slice by hand. This is a rare case but it would be nice + // if reflection could help a little more. + bytes := make([]byte, f.Len()) + for i := range bytes { + bytes[i] = byte(f.Index(i).Uint()) + } + p.fmtBytes(bytes, verb, goSyntax, depth, field) + return verb == 's' + } + if goSyntax { + p.buf.WriteString(reflect.Typeof(field).String()) + p.buf.WriteByte('{') + } else { + p.buf.WriteByte('[') + } + for i := 0; i < f.Len(); i++ { + if i > 0 { + if goSyntax { + p.buf.Write(commaSpaceBytes) + } else { + p.buf.WriteByte(' ') + } + } + p.printField(f.Index(i).Interface(), verb, plus, goSyntax, depth+1) + } + if goSyntax { + p.buf.WriteByte('}') + } else { + p.buf.WriteByte(']') + } + case reflect.Ptr: + v := f.Pointer() + // pointer to array or slice or struct? 
ok at top level + // but not embedded (avoid loops) + if v != 0 && depth == 0 { + switch a := f.Elem(); a.Kind() { + case reflect.Array, reflect.Slice: + p.buf.WriteByte('&') + p.printField(a.Interface(), verb, plus, goSyntax, depth+1) + break BigSwitch + case reflect.Struct: + p.buf.WriteByte('&') + p.printField(a.Interface(), verb, plus, goSyntax, depth+1) + break BigSwitch + } + } + if goSyntax { + p.buf.WriteByte('(') + p.buf.WriteString(reflect.Typeof(field).String()) + p.buf.WriteByte(')') + p.buf.WriteByte('(') + if v == 0 { + p.buf.Write(nilBytes) + } else { + p.fmt0x64(uint64(v), true) + } + p.buf.WriteByte(')') + break + } + if v == 0 { + p.buf.Write(nilAngleBytes) + break + } + p.fmt0x64(uint64(v), true) + case reflect.Chan, reflect.Func, reflect.UnsafePointer: + p.fmtPointer(field, value, verb, goSyntax) + default: + p.unknownType(f) + } + return false +} + +// intFromArg gets the fieldnumth element of a. On return, isInt reports whether the argument has type int. +func intFromArg(a []interface{}, end, i, fieldnum int) (num int, isInt bool, newi, newfieldnum int) { + newi, newfieldnum = end, fieldnum + if i < end && fieldnum < len(a) { + num, isInt = a[fieldnum].(int) + newi, newfieldnum = i+1, fieldnum+1 + } + return +} + +func (p *pp) doPrintf(format string, a []interface{}) { + end := len(format) + fieldnum := 0 // we process one field per non-trivial format + for i := 0; i < end; { + lasti := i + for i < end && format[i] != '%' { + i++ + } + if i > lasti { + p.buf.WriteString(format[lasti:i]) + } + if i >= end { + // done processing format string + break + } + + // Process one verb + i++ + // flags and widths + p.fmt.clearflags() + F: + for ; i < end; i++ { + switch format[i] { + case '#': + p.fmt.sharp = true + case '0': + p.fmt.zero = true + case '+': + p.fmt.plus = true + case '-': + p.fmt.minus = true + case ' ': + p.fmt.space = true + default: + break F + } + } + // do we have width? + if i < end && format[i] == '*' { + p.fmt.wid, p.fmt.widPresent, i, fieldnum = intFromArg(a, end, i, fieldnum) + if !p.fmt.widPresent { + p.buf.Write(widthBytes) + } + } else { + p.fmt.wid, p.fmt.widPresent, i = parsenum(format, i, end) + } + // do we have precision? + if i < end && format[i] == '.' { + if format[i+1] == '*' { + p.fmt.prec, p.fmt.precPresent, i, fieldnum = intFromArg(a, end, i+1, fieldnum) + if !p.fmt.precPresent { + p.buf.Write(precBytes) + } + } else { + p.fmt.prec, p.fmt.precPresent, i = parsenum(format, i+1, end) + } + } + if i >= end { + p.buf.Write(noVerbBytes) + continue + } + c, w := utf8.DecodeRuneInString(format[i:]) + i += w + // percent is special - absorbs no operand + if c == '%' { + p.buf.WriteByte('%') // We ignore width and prec. 
+ continue + } + if fieldnum >= len(a) { // out of operands + p.buf.WriteByte('%') + p.add(c) + p.buf.Write(missingBytes) + continue + } + field := a[fieldnum] + fieldnum++ + + goSyntax := c == 'v' && p.fmt.sharp + plus := c == 'v' && p.fmt.plus + p.printField(field, c, plus, goSyntax, 0) + } + + if fieldnum < len(a) { + p.buf.Write(extraBytes) + for ; fieldnum < len(a); fieldnum++ { + field := a[fieldnum] + if field != nil { + p.buf.WriteString(reflect.Typeof(field).String()) + p.buf.WriteByte('=') + } + p.printField(field, 'v', false, false, 0) + if fieldnum+1 < len(a) { + p.buf.Write(commaSpaceBytes) + } + } + p.buf.WriteByte(')') + } +} + +func (p *pp) doPrint(a []interface{}, addspace, addnewline bool) { + prevString := false + for fieldnum := 0; fieldnum < len(a); fieldnum++ { + p.fmt.clearflags() + // always add spaces if we're doing println + field := a[fieldnum] + if fieldnum > 0 { + isString := field != nil && reflect.Typeof(field).Kind() == reflect.String + if addspace || !isString && !prevString { + p.buf.WriteByte(' ') + } + } + prevString = p.printField(field, 'v', false, false, 0) + } + if addnewline { + p.buf.WriteByte('\n') + } +} diff --git a/src/cmd/gofix/testdata/reflect.quick.go.in b/src/cmd/gofix/testdata/reflect.quick.go.in new file mode 100644 index 000000000..a5568b048 --- /dev/null +++ b/src/cmd/gofix/testdata/reflect.quick.go.in @@ -0,0 +1,364 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This package implements utility functions to help with black box testing. +package quick + +import ( + "flag" + "fmt" + "math" + "os" + "rand" + "reflect" + "strings" +) + +var defaultMaxCount *int = flag.Int("quickchecks", 100, "The default number of iterations for each check") + +// A Generator can generate random values of its own type. +type Generator interface { + // Generate returns a random instance of the type on which it is a + // method using the size as a size hint. + Generate(rand *rand.Rand, size int) reflect.Value +} + +// randFloat32 generates a random float taking the full range of a float32. +func randFloat32(rand *rand.Rand) float32 { + f := rand.Float64() * math.MaxFloat32 + if rand.Int()&1 == 1 { + f = -f + } + return float32(f) +} + +// randFloat64 generates a random float taking the full range of a float64. +func randFloat64(rand *rand.Rand) float64 { + f := rand.Float64() + if rand.Int()&1 == 1 { + f = -f + } + return f +} + +// randInt64 returns a random integer taking half the range of an int64. +func randInt64(rand *rand.Rand) int64 { return rand.Int63() - 1<<62 } + +// complexSize is the maximum length of arbitrary values that contain other +// values. +const complexSize = 50 + +// Value returns an arbitrary value of the given type. +// If the type implements the Generator interface, that will be used. +// Note: in order to create arbitrary values for structs, all the members must be public. 
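The Value function that follows is written against the old reflect API; its counterpart in reflect.quick.go.out further down uses the new one. For orientation, a minimal sketch of the rewrite gofix performs, shown for integer access (x is a placeholder value; the two halves target different versions of the reflect package, so they are separate fragments, not one program):

	// Old API, as in the *.in files: one concrete value type per kind.
	v := reflect.NewValue(x)
	if iv, ok := v.(*reflect.IntValue); ok {
		n := iv.Get() // n has type int64
		_ = n
	}

	// New API, as in the *.out files: a single reflect.Value plus a Kind check.
	v := reflect.NewValue(x)
	if v.Kind() == reflect.Int {
		n := v.Int() // n has type int64
		_ = n
	}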
+func Value(t reflect.Type, rand *rand.Rand) (value reflect.Value, ok bool) { + if m, ok := reflect.MakeZero(t).Interface().(Generator); ok { + return m.Generate(rand, complexSize), true + } + + switch concrete := t.(type) { + case *reflect.BoolType: + return reflect.NewValue(rand.Int()&1 == 0), true + case *reflect.FloatType, *reflect.IntType, *reflect.UintType, *reflect.ComplexType: + switch t.Kind() { + case reflect.Float32: + return reflect.NewValue(randFloat32(rand)), true + case reflect.Float64: + return reflect.NewValue(randFloat64(rand)), true + case reflect.Complex64: + return reflect.NewValue(complex(randFloat32(rand), randFloat32(rand))), true + case reflect.Complex128: + return reflect.NewValue(complex(randFloat64(rand), randFloat64(rand))), true + case reflect.Int16: + return reflect.NewValue(int16(randInt64(rand))), true + case reflect.Int32: + return reflect.NewValue(int32(randInt64(rand))), true + case reflect.Int64: + return reflect.NewValue(randInt64(rand)), true + case reflect.Int8: + return reflect.NewValue(int8(randInt64(rand))), true + case reflect.Int: + return reflect.NewValue(int(randInt64(rand))), true + case reflect.Uint16: + return reflect.NewValue(uint16(randInt64(rand))), true + case reflect.Uint32: + return reflect.NewValue(uint32(randInt64(rand))), true + case reflect.Uint64: + return reflect.NewValue(uint64(randInt64(rand))), true + case reflect.Uint8: + return reflect.NewValue(uint8(randInt64(rand))), true + case reflect.Uint: + return reflect.NewValue(uint(randInt64(rand))), true + case reflect.Uintptr: + return reflect.NewValue(uintptr(randInt64(rand))), true + } + case *reflect.MapType: + numElems := rand.Intn(complexSize) + m := reflect.MakeMap(concrete) + for i := 0; i < numElems; i++ { + key, ok1 := Value(concrete.Key(), rand) + value, ok2 := Value(concrete.Elem(), rand) + if !ok1 || !ok2 { + return nil, false + } + m.SetElem(key, value) + } + return m, true + case *reflect.PtrType: + v, ok := Value(concrete.Elem(), rand) + if !ok { + return nil, false + } + p := reflect.MakeZero(concrete) + p.(*reflect.PtrValue).PointTo(v) + return p, true + case *reflect.SliceType: + numElems := rand.Intn(complexSize) + s := reflect.MakeSlice(concrete, numElems, numElems) + for i := 0; i < numElems; i++ { + v, ok := Value(concrete.Elem(), rand) + if !ok { + return nil, false + } + s.Elem(i).SetValue(v) + } + return s, true + case *reflect.StringType: + numChars := rand.Intn(complexSize) + codePoints := make([]int, numChars) + for i := 0; i < numChars; i++ { + codePoints[i] = rand.Intn(0x10ffff) + } + return reflect.NewValue(string(codePoints)), true + case *reflect.StructType: + s := reflect.MakeZero(t).(*reflect.StructValue) + for i := 0; i < s.NumField(); i++ { + v, ok := Value(concrete.Field(i).Type, rand) + if !ok { + return nil, false + } + s.Field(i).SetValue(v) + } + return s, true + default: + return nil, false + } + + return +} + +// A Config structure contains options for running a test. +type Config struct { + // MaxCount sets the maximum number of iterations. If zero, + // MaxCountScale is used. + MaxCount int + // MaxCountScale is a non-negative scale factor applied to the default + // maximum. If zero, the default is unchanged. + MaxCountScale float64 + // If non-nil, rand is a source of random numbers. Otherwise a default + // pseudo-random source will be used. + Rand *rand.Rand + // If non-nil, Values is a function which generates a slice of arbitrary + // Values that are congruent with the arguments to the function being + // tested. 
Otherwise, Values is used to generate the values. + Values func([]reflect.Value, *rand.Rand) +} + +var defaultConfig Config + +// getRand returns the *rand.Rand to use for a given Config. +func (c *Config) getRand() *rand.Rand { + if c.Rand == nil { + return rand.New(rand.NewSource(0)) + } + return c.Rand +} + +// getMaxCount returns the maximum number of iterations to run for a given +// Config. +func (c *Config) getMaxCount() (maxCount int) { + maxCount = c.MaxCount + if maxCount == 0 { + if c.MaxCountScale != 0 { + maxCount = int(c.MaxCountScale * float64(*defaultMaxCount)) + } else { + maxCount = *defaultMaxCount + } + } + + return +} + +// A SetupError is the result of an error in the way that check is being +// used, independent of the functions being tested. +type SetupError string + +func (s SetupError) String() string { return string(s) } + +// A CheckError is the result of Check finding an error. +type CheckError struct { + Count int + In []interface{} +} + +func (s *CheckError) String() string { + return fmt.Sprintf("#%d: failed on input %s", s.Count, toString(s.In)) +} + +// A CheckEqualError is the result CheckEqual finding an error. +type CheckEqualError struct { + CheckError + Out1 []interface{} + Out2 []interface{} +} + +func (s *CheckEqualError) String() string { + return fmt.Sprintf("#%d: failed on input %s. Output 1: %s. Output 2: %s", s.Count, toString(s.In), toString(s.Out1), toString(s.Out2)) +} + +// Check looks for an input to f, any function that returns bool, +// such that f returns false. It calls f repeatedly, with arbitrary +// values for each argument. If f returns false on a given input, +// Check returns that input as a *CheckError. +// For example: +// +// func TestOddMultipleOfThree(t *testing.T) { +// f := func(x int) bool { +// y := OddMultipleOfThree(x) +// return y%2 == 1 && y%3 == 0 +// } +// if err := quick.Check(f, nil); err != nil { +// t.Error(err) +// } +// } +func Check(function interface{}, config *Config) (err os.Error) { + if config == nil { + config = &defaultConfig + } + + f, fType, ok := functionAndType(function) + if !ok { + err = SetupError("argument is not a function") + return + } + + if fType.NumOut() != 1 { + err = SetupError("function returns more than one value.") + return + } + if _, ok := fType.Out(0).(*reflect.BoolType); !ok { + err = SetupError("function does not return a bool") + return + } + + arguments := make([]reflect.Value, fType.NumIn()) + rand := config.getRand() + maxCount := config.getMaxCount() + + for i := 0; i < maxCount; i++ { + err = arbitraryValues(arguments, fType, config, rand) + if err != nil { + return + } + + if !f.Call(arguments)[0].(*reflect.BoolValue).Get() { + err = &CheckError{i + 1, toInterfaces(arguments)} + return + } + } + + return +} + +// CheckEqual looks for an input on which f and g return different results. +// It calls f and g repeatedly with arbitrary values for each argument. +// If f and g return different answers, CheckEqual returns a *CheckEqualError +// describing the input and the outputs. 
+func CheckEqual(f, g interface{}, config *Config) (err os.Error) { + if config == nil { + config = &defaultConfig + } + + x, xType, ok := functionAndType(f) + if !ok { + err = SetupError("f is not a function") + return + } + y, yType, ok := functionAndType(g) + if !ok { + err = SetupError("g is not a function") + return + } + + if xType != yType { + err = SetupError("functions have different types") + return + } + + arguments := make([]reflect.Value, xType.NumIn()) + rand := config.getRand() + maxCount := config.getMaxCount() + + for i := 0; i < maxCount; i++ { + err = arbitraryValues(arguments, xType, config, rand) + if err != nil { + return + } + + xOut := toInterfaces(x.Call(arguments)) + yOut := toInterfaces(y.Call(arguments)) + + if !reflect.DeepEqual(xOut, yOut) { + err = &CheckEqualError{CheckError{i + 1, toInterfaces(arguments)}, xOut, yOut} + return + } + } + + return +} + +// arbitraryValues writes Values to args such that args contains Values +// suitable for calling f. +func arbitraryValues(args []reflect.Value, f *reflect.FuncType, config *Config, rand *rand.Rand) (err os.Error) { + if config.Values != nil { + config.Values(args, rand) + return + } + + for j := 0; j < len(args); j++ { + var ok bool + args[j], ok = Value(f.In(j), rand) + if !ok { + err = SetupError(fmt.Sprintf("cannot create arbitrary value of type %s for argument %d", f.In(j), j)) + return + } + } + + return +} + +func functionAndType(f interface{}) (v *reflect.FuncValue, t *reflect.FuncType, ok bool) { + v, ok = reflect.NewValue(f).(*reflect.FuncValue) + if !ok { + return + } + t = v.Type().(*reflect.FuncType) + return +} + +func toInterfaces(values []reflect.Value) []interface{} { + ret := make([]interface{}, len(values)) + for i, v := range values { + ret[i] = v.Interface() + } + return ret +} + +func toString(interfaces []interface{}) string { + s := make([]string, len(interfaces)) + for i, v := range interfaces { + s[i] = fmt.Sprintf("%#v", v) + } + return strings.Join(s, ", ") +} diff --git a/src/cmd/gofix/testdata/reflect.quick.go.out b/src/cmd/gofix/testdata/reflect.quick.go.out new file mode 100644 index 000000000..152dbad32 --- /dev/null +++ b/src/cmd/gofix/testdata/reflect.quick.go.out @@ -0,0 +1,365 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This package implements utility functions to help with black box testing. +package quick + +import ( + "flag" + "fmt" + "math" + "os" + "rand" + "reflect" + "strings" +) + +var defaultMaxCount *int = flag.Int("quickchecks", 100, "The default number of iterations for each check") + +// A Generator can generate random values of its own type. +type Generator interface { + // Generate returns a random instance of the type on which it is a + // method using the size as a size hint. + Generate(rand *rand.Rand, size int) reflect.Value +} + +// randFloat32 generates a random float taking the full range of a float32. +func randFloat32(rand *rand.Rand) float32 { + f := rand.Float64() * math.MaxFloat32 + if rand.Int()&1 == 1 { + f = -f + } + return float32(f) +} + +// randFloat64 generates a random float taking the full range of a float64. +func randFloat64(rand *rand.Rand) float64 { + f := rand.Float64() + if rand.Int()&1 == 1 { + f = -f + } + return f +} + +// randInt64 returns a random integer taking half the range of an int64. 
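Before the remaining helpers, a hypothetical Generator implementation against this snapshot's API (reflect.NewValue wrapping a concrete value); the Point type is invented for the example and is not part of the patch:

	// Point provides its own arbitrary values instead of relying on Value's defaults.
	type Point struct{ X, Y int }

	// Generate satisfies Generator; size is used as a rough bound on the
	// coordinates (Value passes complexSize, so it is positive here).
	func (Point) Generate(r *rand.Rand, size int) reflect.Value {
		return reflect.NewValue(Point{r.Intn(size), r.Intn(size)})
	}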
+func randInt64(rand *rand.Rand) int64 { return rand.Int63() - 1<<62 } + +// complexSize is the maximum length of arbitrary values that contain other +// values. +const complexSize = 50 + +// Value returns an arbitrary value of the given type. +// If the type implements the Generator interface, that will be used. +// Note: in order to create arbitrary values for structs, all the members must be public. +func Value(t reflect.Type, rand *rand.Rand) (value reflect.Value, ok bool) { + if m, ok := reflect.Zero(t).Interface().(Generator); ok { + return m.Generate(rand, complexSize), true + } + + switch concrete := t; concrete.Kind() { + case reflect.Bool: + return reflect.NewValue(rand.Int()&1 == 0), true + case reflect.Float32, reflect.Float64, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr, reflect.Complex64, reflect.Complex128: + switch t.Kind() { + case reflect.Float32: + return reflect.NewValue(randFloat32(rand)), true + case reflect.Float64: + return reflect.NewValue(randFloat64(rand)), true + case reflect.Complex64: + return reflect.NewValue(complex(randFloat32(rand), randFloat32(rand))), true + case reflect.Complex128: + return reflect.NewValue(complex(randFloat64(rand), randFloat64(rand))), true + case reflect.Int16: + return reflect.NewValue(int16(randInt64(rand))), true + case reflect.Int32: + return reflect.NewValue(int32(randInt64(rand))), true + case reflect.Int64: + return reflect.NewValue(randInt64(rand)), true + case reflect.Int8: + return reflect.NewValue(int8(randInt64(rand))), true + case reflect.Int: + return reflect.NewValue(int(randInt64(rand))), true + case reflect.Uint16: + return reflect.NewValue(uint16(randInt64(rand))), true + case reflect.Uint32: + return reflect.NewValue(uint32(randInt64(rand))), true + case reflect.Uint64: + return reflect.NewValue(uint64(randInt64(rand))), true + case reflect.Uint8: + return reflect.NewValue(uint8(randInt64(rand))), true + case reflect.Uint: + return reflect.NewValue(uint(randInt64(rand))), true + case reflect.Uintptr: + return reflect.NewValue(uintptr(randInt64(rand))), true + } + case reflect.Map: + numElems := rand.Intn(complexSize) + m := reflect.MakeMap(concrete) + for i := 0; i < numElems; i++ { + key, ok1 := Value(concrete.Key(), rand) + value, ok2 := Value(concrete.Elem(), rand) + if !ok1 || !ok2 { + return reflect.Value{}, false + } + m.SetMapIndex(key, value) + } + return m, true + case reflect.Ptr: + v, ok := Value(concrete.Elem(), rand) + if !ok { + return reflect.Value{}, false + } + p := reflect.Zero(concrete) + p.Set(v.Addr()) + return p, true + case reflect.Slice: + numElems := rand.Intn(complexSize) + s := reflect.MakeSlice(concrete, numElems, numElems) + for i := 0; i < numElems; i++ { + v, ok := Value(concrete.Elem(), rand) + if !ok { + return reflect.Value{}, false + } + s.Index(i).Set(v) + } + return s, true + case reflect.String: + numChars := rand.Intn(complexSize) + codePoints := make([]int, numChars) + for i := 0; i < numChars; i++ { + codePoints[i] = rand.Intn(0x10ffff) + } + return reflect.NewValue(string(codePoints)), true + case reflect.Struct: + s := reflect.Zero(t) + for i := 0; i < s.NumField(); i++ { + v, ok := Value(concrete.Field(i).Type, rand) + if !ok { + return reflect.Value{}, false + } + s.Field(i).Set(v) + } + return s, true + default: + return reflect.Value{}, false + } + + return +} + +// A Config structure contains options for running a test. 
+type Config struct { + // MaxCount sets the maximum number of iterations. If zero, + // MaxCountScale is used. + MaxCount int + // MaxCountScale is a non-negative scale factor applied to the default + // maximum. If zero, the default is unchanged. + MaxCountScale float64 + // If non-nil, rand is a source of random numbers. Otherwise a default + // pseudo-random source will be used. + Rand *rand.Rand + // If non-nil, Values is a function which generates a slice of arbitrary + // Values that are congruent with the arguments to the function being + // tested. Otherwise, Values is used to generate the values. + Values func([]reflect.Value, *rand.Rand) +} + +var defaultConfig Config + +// getRand returns the *rand.Rand to use for a given Config. +func (c *Config) getRand() *rand.Rand { + if c.Rand == nil { + return rand.New(rand.NewSource(0)) + } + return c.Rand +} + +// getMaxCount returns the maximum number of iterations to run for a given +// Config. +func (c *Config) getMaxCount() (maxCount int) { + maxCount = c.MaxCount + if maxCount == 0 { + if c.MaxCountScale != 0 { + maxCount = int(c.MaxCountScale * float64(*defaultMaxCount)) + } else { + maxCount = *defaultMaxCount + } + } + + return +} + +// A SetupError is the result of an error in the way that check is being +// used, independent of the functions being tested. +type SetupError string + +func (s SetupError) String() string { return string(s) } + +// A CheckError is the result of Check finding an error. +type CheckError struct { + Count int + In []interface{} +} + +func (s *CheckError) String() string { + return fmt.Sprintf("#%d: failed on input %s", s.Count, toString(s.In)) +} + +// A CheckEqualError is the result CheckEqual finding an error. +type CheckEqualError struct { + CheckError + Out1 []interface{} + Out2 []interface{} +} + +func (s *CheckEqualError) String() string { + return fmt.Sprintf("#%d: failed on input %s. Output 1: %s. Output 2: %s", s.Count, toString(s.In), toString(s.Out1), toString(s.Out2)) +} + +// Check looks for an input to f, any function that returns bool, +// such that f returns false. It calls f repeatedly, with arbitrary +// values for each argument. If f returns false on a given input, +// Check returns that input as a *CheckError. +// For example: +// +// func TestOddMultipleOfThree(t *testing.T) { +// f := func(x int) bool { +// y := OddMultipleOfThree(x) +// return y%2 == 1 && y%3 == 0 +// } +// if err := quick.Check(f, nil); err != nil { +// t.Error(err) +// } +// } +func Check(function interface{}, config *Config) (err os.Error) { + if config == nil { + config = &defaultConfig + } + + f, fType, ok := functionAndType(function) + if !ok { + err = SetupError("argument is not a function") + return + } + + if fType.NumOut() != 1 { + err = SetupError("function returns more than one value.") + return + } + if fType.Out(0).Kind() != reflect.Bool { + err = SetupError("function does not return a bool") + return + } + + arguments := make([]reflect.Value, fType.NumIn()) + rand := config.getRand() + maxCount := config.getMaxCount() + + for i := 0; i < maxCount; i++ { + err = arbitraryValues(arguments, fType, config, rand) + if err != nil { + return + } + + if !f.Call(arguments)[0].Bool() { + err = &CheckError{i + 1, toInterfaces(arguments)} + return + } + } + + return +} + +// CheckEqual looks for an input on which f and g return different results. +// It calls f and g repeatedly with arbitrary values for each argument. 
+// If f and g return different answers, CheckEqual returns a *CheckEqualError +// describing the input and the outputs. +func CheckEqual(f, g interface{}, config *Config) (err os.Error) { + if config == nil { + config = &defaultConfig + } + + x, xType, ok := functionAndType(f) + if !ok { + err = SetupError("f is not a function") + return + } + y, yType, ok := functionAndType(g) + if !ok { + err = SetupError("g is not a function") + return + } + + if xType != yType { + err = SetupError("functions have different types") + return + } + + arguments := make([]reflect.Value, xType.NumIn()) + rand := config.getRand() + maxCount := config.getMaxCount() + + for i := 0; i < maxCount; i++ { + err = arbitraryValues(arguments, xType, config, rand) + if err != nil { + return + } + + xOut := toInterfaces(x.Call(arguments)) + yOut := toInterfaces(y.Call(arguments)) + + if !reflect.DeepEqual(xOut, yOut) { + err = &CheckEqualError{CheckError{i + 1, toInterfaces(arguments)}, xOut, yOut} + return + } + } + + return +} + +// arbitraryValues writes Values to args such that args contains Values +// suitable for calling f. +func arbitraryValues(args []reflect.Value, f reflect.Type, config *Config, rand *rand.Rand) (err os.Error) { + if config.Values != nil { + config.Values(args, rand) + return + } + + for j := 0; j < len(args); j++ { + var ok bool + args[j], ok = Value(f.In(j), rand) + if !ok { + err = SetupError(fmt.Sprintf("cannot create arbitrary value of type %s for argument %d", f.In(j), j)) + return + } + } + + return +} + +func functionAndType(f interface{}) (v reflect.Value, t reflect.Type, ok bool) { + v = reflect.NewValue(f) + ok = v.Kind() == reflect.Func + if !ok { + return + } + t = v.Type() + return +} + +func toInterfaces(values []reflect.Value) []interface{} { + ret := make([]interface{}, len(values)) + for i, v := range values { + ret[i] = v.Interface() + } + return ret +} + +func toString(interfaces []interface{}) string { + s := make([]string, len(interfaces)) + for i, v := range interfaces { + s[i] = fmt.Sprintf("%#v", v) + } + return strings.Join(s, ", ") +} diff --git a/src/cmd/gofix/testdata/reflect.read.go.in b/src/cmd/gofix/testdata/reflect.read.go.in new file mode 100644 index 000000000..9ae3bb8ee --- /dev/null +++ b/src/cmd/gofix/testdata/reflect.read.go.in @@ -0,0 +1,620 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xml + +import ( + "bytes" + "fmt" + "io" + "os" + "reflect" + "strconv" + "strings" + "unicode" + "utf8" +) + +// BUG(rsc): Mapping between XML elements and data structures is inherently flawed: +// an XML element is an order-dependent collection of anonymous +// values, while a data structure is an order-independent collection +// of named values. +// See package json for a textual representation more suitable +// to data structures. + +// Unmarshal parses an XML element from r and uses the +// reflect library to fill in an arbitrary struct, slice, or string +// pointed at by val. Well-formed data that does not fit +// into val is discarded. 
+// +// For example, given these definitions: +// +// type Email struct { +// Where string "attr" +// Addr string +// } +// +// type Result struct { +// XMLName xml.Name "result" +// Name string +// Phone string +// Email []Email +// Groups []string "group>value" +// } +// +// result := Result{Name: "name", Phone: "phone", Email: nil} +// +// unmarshalling the XML input +// +// <result> +// <email where="home"> +// <addr>gre@example.com</addr> +// </email> +// <email where='work'> +// <addr>gre@work.com</addr> +// </email> +// <name>Grace R. Emlin</name> +// <group> +// <value>Friends</value> +// <value>Squash</value> +// </group> +// <address>123 Main Street</address> +// </result> +// +// via Unmarshal(r, &result) is equivalent to assigning +// +// r = Result{xml.Name{"", "result"}, +// "Grace R. Emlin", // name +// "phone", // no phone given +// []Email{ +// Email{"home", "gre@example.com"}, +// Email{"work", "gre@work.com"}, +// }, +// []string{"Friends", "Squash"}, +// } +// +// Note that the field r.Phone has not been modified and +// that the XML <address> element was discarded. Also, the field +// Groups was assigned considering the element path provided in the +// field tag. +// +// Because Unmarshal uses the reflect package, it can only +// assign to upper case fields. Unmarshal uses a case-insensitive +// comparison to match XML element names to struct field names. +// +// Unmarshal maps an XML element to a struct using the following rules: +// +// * If the struct has a field of type []byte or string with tag "innerxml", +// Unmarshal accumulates the raw XML nested inside the element +// in that field. The rest of the rules still apply. +// +// * If the struct has a field named XMLName of type xml.Name, +// Unmarshal records the element name in that field. +// +// * If the XMLName field has an associated tag string of the form +// "tag" or "namespace-URL tag", the XML element must have +// the given tag (and, optionally, name space) or else Unmarshal +// returns an error. +// +// * If the XML element has an attribute whose name matches a +// struct field of type string with tag "attr", Unmarshal records +// the attribute value in that field. +// +// * If the XML element contains character data, that data is +// accumulated in the first struct field that has tag "chardata". +// The struct field may have type []byte or string. +// If there is no such field, the character data is discarded. +// +// * If the XML element contains a sub-element whose name matches +// the prefix of a struct field tag formatted as "a>b>c", unmarshal +// will descend into the XML structure looking for elements with the +// given names, and will map the innermost elements to that struct field. +// A struct field tag starting with ">" is equivalent to one starting +// with the field name followed by ">". +// +// * If the XML element contains a sub-element whose name +// matches a struct field whose tag is neither "attr" nor "chardata", +// Unmarshal maps the sub-element to that struct field. +// Otherwise, if the struct has a field named Any, unmarshal +// maps the sub-element to that struct field. +// +// Unmarshal maps an XML element to a string or []byte by saving the +// concatenation of that element's character data in the string or []byte. +// +// Unmarshal maps an XML element to a slice by extending the length +// of the slice and mapping the element to the newly created value. +// +// Unmarshal maps an XML element to a bool by setting it to the boolean +// value represented by the string. 
+// +// Unmarshal maps an XML element to an integer or floating-point +// field by setting the field to the result of interpreting the string +// value in decimal. There is no check for overflow. +// +// Unmarshal maps an XML element to an xml.Name by recording the +// element name. +// +// Unmarshal maps an XML element to a pointer by setting the pointer +// to a freshly allocated value and then mapping the element to that value. +// +func Unmarshal(r io.Reader, val interface{}) os.Error { + v, ok := reflect.NewValue(val).(*reflect.PtrValue) + if !ok { + return os.NewError("non-pointer passed to Unmarshal") + } + p := NewParser(r) + elem := v.Elem() + err := p.unmarshal(elem, nil) + if err != nil { + return err + } + return nil +} + +// An UnmarshalError represents an error in the unmarshalling process. +type UnmarshalError string + +func (e UnmarshalError) String() string { return string(e) } + +// A TagPathError represents an error in the unmarshalling process +// caused by the use of field tags with conflicting paths. +type TagPathError struct { + Struct reflect.Type + Field1, Tag1 string + Field2, Tag2 string +} + +func (e *TagPathError) String() string { + return fmt.Sprintf("%s field %q with tag %q conflicts with field %q with tag %q", e.Struct, e.Field1, e.Tag1, e.Field2, e.Tag2) +} + +// The Parser's Unmarshal method is like xml.Unmarshal +// except that it can be passed a pointer to the initial start element, +// useful when a client reads some raw XML tokens itself +// but also defers to Unmarshal for some elements. +// Passing a nil start element indicates that Unmarshal should +// read the token stream to find the start element. +func (p *Parser) Unmarshal(val interface{}, start *StartElement) os.Error { + v, ok := reflect.NewValue(val).(*reflect.PtrValue) + if !ok { + return os.NewError("non-pointer passed to Unmarshal") + } + return p.unmarshal(v.Elem(), start) +} + +// fieldName strips invalid characters from an XML name +// to create a valid Go struct name. It also converts the +// name to lower case letters. +func fieldName(original string) string { + + var i int + //remove leading underscores + for i = 0; i < len(original) && original[i] == '_'; i++ { + } + + return strings.Map( + func(x int) int { + if x == '_' || unicode.IsDigit(x) || unicode.IsLetter(x) { + return unicode.ToLower(x) + } + return -1 + }, + original[i:]) +} + +// Unmarshal a single XML element into val. +func (p *Parser) unmarshal(val reflect.Value, start *StartElement) os.Error { + // Find start element if we need it. + if start == nil { + for { + tok, err := p.Token() + if err != nil { + return err + } + if t, ok := tok.(StartElement); ok { + start = &t + break + } + } + } + + if pv, ok := val.(*reflect.PtrValue); ok { + if pv.Get() == 0 { + zv := reflect.MakeZero(pv.Type().(*reflect.PtrType).Elem()) + pv.PointTo(zv) + val = zv + } else { + val = pv.Elem() + } + } + + var ( + data []byte + saveData reflect.Value + comment []byte + saveComment reflect.Value + saveXML reflect.Value + saveXMLIndex int + saveXMLData []byte + sv *reflect.StructValue + styp *reflect.StructType + fieldPaths map[string]pathInfo + ) + + switch v := val.(type) { + default: + return os.ErrorString("unknown type " + v.Type().String()) + + case *reflect.SliceValue: + typ := v.Type().(*reflect.SliceType) + if typ.Elem().Kind() == reflect.Uint8 { + // []byte + saveData = v + break + } + + // Slice of element values. + // Grow slice. 
+ n := v.Len() + if n >= v.Cap() { + ncap := 2 * n + if ncap < 4 { + ncap = 4 + } + new := reflect.MakeSlice(typ, n, ncap) + reflect.Copy(new, v) + v.Set(new) + } + v.SetLen(n + 1) + + // Recur to read element into slice. + if err := p.unmarshal(v.Elem(n), start); err != nil { + v.SetLen(n) + return err + } + return nil + + case *reflect.BoolValue, *reflect.FloatValue, *reflect.IntValue, *reflect.UintValue, *reflect.StringValue: + saveData = v + + case *reflect.StructValue: + if _, ok := v.Interface().(Name); ok { + v.Set(reflect.NewValue(start.Name).(*reflect.StructValue)) + break + } + + sv = v + typ := sv.Type().(*reflect.StructType) + styp = typ + // Assign name. + if f, ok := typ.FieldByName("XMLName"); ok { + // Validate element name. + if f.Tag != "" { + tag := f.Tag + ns := "" + i := strings.LastIndex(tag, " ") + if i >= 0 { + ns, tag = tag[0:i], tag[i+1:] + } + if tag != start.Name.Local { + return UnmarshalError("expected element type <" + tag + "> but have <" + start.Name.Local + ">") + } + if ns != "" && ns != start.Name.Space { + e := "expected element <" + tag + "> in name space " + ns + " but have " + if start.Name.Space == "" { + e += "no name space" + } else { + e += start.Name.Space + } + return UnmarshalError(e) + } + } + + // Save + v := sv.FieldByIndex(f.Index) + if _, ok := v.Interface().(Name); !ok { + return UnmarshalError(sv.Type().String() + " field XMLName does not have type xml.Name") + } + v.(*reflect.StructValue).Set(reflect.NewValue(start.Name).(*reflect.StructValue)) + } + + // Assign attributes. + // Also, determine whether we need to save character data or comments. + for i, n := 0, typ.NumField(); i < n; i++ { + f := typ.Field(i) + switch f.Tag { + case "attr": + strv, ok := sv.FieldByIndex(f.Index).(*reflect.StringValue) + if !ok { + return UnmarshalError(sv.Type().String() + " field " + f.Name + " has attr tag but is not type string") + } + // Look for attribute. + val := "" + k := strings.ToLower(f.Name) + for _, a := range start.Attr { + if fieldName(a.Name.Local) == k { + val = a.Value + break + } + } + strv.Set(val) + + case "comment": + if saveComment == nil { + saveComment = sv.FieldByIndex(f.Index) + } + + case "chardata": + if saveData == nil { + saveData = sv.FieldByIndex(f.Index) + } + + case "innerxml": + if saveXML == nil { + saveXML = sv.FieldByIndex(f.Index) + if p.saved == nil { + saveXMLIndex = 0 + p.saved = new(bytes.Buffer) + } else { + saveXMLIndex = p.savedOffset() + } + } + + default: + if strings.Contains(f.Tag, ">") { + if fieldPaths == nil { + fieldPaths = make(map[string]pathInfo) + } + path := strings.ToLower(f.Tag) + if strings.HasPrefix(f.Tag, ">") { + path = strings.ToLower(f.Name) + path + } + if strings.HasSuffix(f.Tag, ">") { + path = path[:len(path)-1] + } + err := addFieldPath(sv, fieldPaths, path, f.Index) + if err != nil { + return err + } + } + } + } + } + + // Find end element. + // Process sub-elements along the way. +Loop: + for { + var savedOffset int + if saveXML != nil { + savedOffset = p.savedOffset() + } + tok, err := p.Token() + if err != nil { + return err + } + switch t := tok.(type) { + case StartElement: + // Sub-element. + // Look up by tag name. 
+ if sv != nil { + k := fieldName(t.Name.Local) + + if fieldPaths != nil { + if _, found := fieldPaths[k]; found { + if err := p.unmarshalPaths(sv, fieldPaths, k, &t); err != nil { + return err + } + continue Loop + } + } + + match := func(s string) bool { + // check if the name matches ignoring case + if strings.ToLower(s) != k { + return false + } + // now check that it's public + c, _ := utf8.DecodeRuneInString(s) + return unicode.IsUpper(c) + } + + f, found := styp.FieldByNameFunc(match) + if !found { // fall back to mop-up field named "Any" + f, found = styp.FieldByName("Any") + } + if found { + if err := p.unmarshal(sv.FieldByIndex(f.Index), &t); err != nil { + return err + } + continue Loop + } + } + // Not saving sub-element but still have to skip over it. + if err := p.Skip(); err != nil { + return err + } + + case EndElement: + if saveXML != nil { + saveXMLData = p.saved.Bytes()[saveXMLIndex:savedOffset] + if saveXMLIndex == 0 { + p.saved = nil + } + } + break Loop + + case CharData: + if saveData != nil { + data = append(data, t...) + } + + case Comment: + if saveComment != nil { + comment = append(comment, t...) + } + } + } + + var err os.Error + // Helper functions for integer and unsigned integer conversions + var itmp int64 + getInt64 := func() bool { + itmp, err = strconv.Atoi64(string(data)) + // TODO: should check sizes + return err == nil + } + var utmp uint64 + getUint64 := func() bool { + utmp, err = strconv.Atoui64(string(data)) + // TODO: check for overflow? + return err == nil + } + var ftmp float64 + getFloat64 := func() bool { + ftmp, err = strconv.Atof64(string(data)) + // TODO: check for overflow? + return err == nil + } + + // Save accumulated data and comments + switch t := saveData.(type) { + case nil: + // Probably a comment, handled below + default: + return os.ErrorString("cannot happen: unknown type " + t.Type().String()) + case *reflect.IntValue: + if !getInt64() { + return err + } + t.Set(itmp) + case *reflect.UintValue: + if !getUint64() { + return err + } + t.Set(utmp) + case *reflect.FloatValue: + if !getFloat64() { + return err + } + t.Set(ftmp) + case *reflect.BoolValue: + value, err := strconv.Atob(strings.TrimSpace(string(data))) + if err != nil { + return err + } + t.Set(value) + case *reflect.StringValue: + t.Set(string(data)) + case *reflect.SliceValue: + t.Set(reflect.NewValue(data).(*reflect.SliceValue)) + } + + switch t := saveComment.(type) { + case *reflect.StringValue: + t.Set(string(comment)) + case *reflect.SliceValue: + t.Set(reflect.NewValue(comment).(*reflect.SliceValue)) + } + + switch t := saveXML.(type) { + case *reflect.StringValue: + t.Set(string(saveXMLData)) + case *reflect.SliceValue: + t.Set(reflect.NewValue(saveXMLData).(*reflect.SliceValue)) + } + + return nil +} + +type pathInfo struct { + fieldIdx []int + complete bool +} + +// addFieldPath takes an element path such as "a>b>c" and fills the +// paths map with all paths leading to it ("a", "a>b", and "a>b>c"). +// It is okay for paths to share a common, shorter prefix but not ok +// for one path to itself be a prefix of another. 
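+// For example, a hypothetical field tagged "a>b>c" yields the entries
+// "a" and "a>b" marked incomplete and "a>b>c" marked complete.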
+func addFieldPath(sv *reflect.StructValue, paths map[string]pathInfo, path string, fieldIdx []int) os.Error { + if info, found := paths[path]; found { + return tagError(sv, info.fieldIdx, fieldIdx) + } + paths[path] = pathInfo{fieldIdx, true} + for { + i := strings.LastIndex(path, ">") + if i < 0 { + break + } + path = path[:i] + if info, found := paths[path]; found { + if info.complete { + return tagError(sv, info.fieldIdx, fieldIdx) + } + } else { + paths[path] = pathInfo{fieldIdx, false} + } + } + return nil + +} + +func tagError(sv *reflect.StructValue, idx1 []int, idx2 []int) os.Error { + t := sv.Type().(*reflect.StructType) + f1 := t.FieldByIndex(idx1) + f2 := t.FieldByIndex(idx2) + return &TagPathError{t, f1.Name, f1.Tag, f2.Name, f2.Tag} +} + +// unmarshalPaths walks down an XML structure looking for +// wanted paths, and calls unmarshal on them. +func (p *Parser) unmarshalPaths(sv *reflect.StructValue, paths map[string]pathInfo, path string, start *StartElement) os.Error { + if info, _ := paths[path]; info.complete { + return p.unmarshal(sv.FieldByIndex(info.fieldIdx), start) + } + for { + tok, err := p.Token() + if err != nil { + return err + } + switch t := tok.(type) { + case StartElement: + k := path + ">" + fieldName(t.Name.Local) + if _, found := paths[k]; found { + if err := p.unmarshalPaths(sv, paths, k, &t); err != nil { + return err + } + continue + } + if err := p.Skip(); err != nil { + return err + } + case EndElement: + return nil + } + } + panic("unreachable") +} + +// Have already read a start element. +// Read tokens until we find the end element. +// Token is taking care of making sure the +// end element matches the start element we saw. +func (p *Parser) Skip() os.Error { + for { + tok, err := p.Token() + if err != nil { + return err + } + switch t := tok.(type) { + case StartElement: + if err := p.Skip(); err != nil { + return err + } + case EndElement: + return nil + } + } + panic("unreachable") +} diff --git a/src/cmd/gofix/testdata/reflect.read.go.out b/src/cmd/gofix/testdata/reflect.read.go.out new file mode 100644 index 000000000..a3ddb9d4c --- /dev/null +++ b/src/cmd/gofix/testdata/reflect.read.go.out @@ -0,0 +1,620 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xml + +import ( + "bytes" + "fmt" + "io" + "os" + "reflect" + "strconv" + "strings" + "unicode" + "utf8" +) + +// BUG(rsc): Mapping between XML elements and data structures is inherently flawed: +// an XML element is an order-dependent collection of anonymous +// values, while a data structure is an order-independent collection +// of named values. +// See package json for a textual representation more suitable +// to data structures. + +// Unmarshal parses an XML element from r and uses the +// reflect library to fill in an arbitrary struct, slice, or string +// pointed at by val. Well-formed data that does not fit +// into val is discarded. 
+// +// For example, given these definitions: +// +// type Email struct { +// Where string "attr" +// Addr string +// } +// +// type Result struct { +// XMLName xml.Name "result" +// Name string +// Phone string +// Email []Email +// Groups []string "group>value" +// } +// +// result := Result{Name: "name", Phone: "phone", Email: nil} +// +// unmarshalling the XML input +// +// <result> +// <email where="home"> +// <addr>gre@example.com</addr> +// </email> +// <email where='work'> +// <addr>gre@work.com</addr> +// </email> +// <name>Grace R. Emlin</name> +// <group> +// <value>Friends</value> +// <value>Squash</value> +// </group> +// <address>123 Main Street</address> +// </result> +// +// via Unmarshal(r, &result) is equivalent to assigning +// +// r = Result{xml.Name{"", "result"}, +// "Grace R. Emlin", // name +// "phone", // no phone given +// []Email{ +// Email{"home", "gre@example.com"}, +// Email{"work", "gre@work.com"}, +// }, +// []string{"Friends", "Squash"}, +// } +// +// Note that the field r.Phone has not been modified and +// that the XML <address> element was discarded. Also, the field +// Groups was assigned considering the element path provided in the +// field tag. +// +// Because Unmarshal uses the reflect package, it can only +// assign to upper case fields. Unmarshal uses a case-insensitive +// comparison to match XML element names to struct field names. +// +// Unmarshal maps an XML element to a struct using the following rules: +// +// * If the struct has a field of type []byte or string with tag "innerxml", +// Unmarshal accumulates the raw XML nested inside the element +// in that field. The rest of the rules still apply. +// +// * If the struct has a field named XMLName of type xml.Name, +// Unmarshal records the element name in that field. +// +// * If the XMLName field has an associated tag string of the form +// "tag" or "namespace-URL tag", the XML element must have +// the given tag (and, optionally, name space) or else Unmarshal +// returns an error. +// +// * If the XML element has an attribute whose name matches a +// struct field of type string with tag "attr", Unmarshal records +// the attribute value in that field. +// +// * If the XML element contains character data, that data is +// accumulated in the first struct field that has tag "chardata". +// The struct field may have type []byte or string. +// If there is no such field, the character data is discarded. +// +// * If the XML element contains a sub-element whose name matches +// the prefix of a struct field tag formatted as "a>b>c", unmarshal +// will descend into the XML structure looking for elements with the +// given names, and will map the innermost elements to that struct field. +// A struct field tag starting with ">" is equivalent to one starting +// with the field name followed by ">". +// +// * If the XML element contains a sub-element whose name +// matches a struct field whose tag is neither "attr" nor "chardata", +// Unmarshal maps the sub-element to that struct field. +// Otherwise, if the struct has a field named Any, unmarshal +// maps the sub-element to that struct field. +// +// Unmarshal maps an XML element to a string or []byte by saving the +// concatenation of that element's character data in the string or []byte. +// +// Unmarshal maps an XML element to a slice by extending the length +// of the slice and mapping the element to the newly created value. +// +// Unmarshal maps an XML element to a bool by setting it to the boolean +// value represented by the string. 
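+//
+// For instance, a hypothetical field of type bool named Flag is set to
+// true when the matching element's character data is the string "true".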
+// +// Unmarshal maps an XML element to an integer or floating-point +// field by setting the field to the result of interpreting the string +// value in decimal. There is no check for overflow. +// +// Unmarshal maps an XML element to an xml.Name by recording the +// element name. +// +// Unmarshal maps an XML element to a pointer by setting the pointer +// to a freshly allocated value and then mapping the element to that value. +// +func Unmarshal(r io.Reader, val interface{}) os.Error { + v := reflect.NewValue(val) + if v.Kind() != reflect.Ptr { + return os.NewError("non-pointer passed to Unmarshal") + } + p := NewParser(r) + elem := v.Elem() + err := p.unmarshal(elem, nil) + if err != nil { + return err + } + return nil +} + +// An UnmarshalError represents an error in the unmarshalling process. +type UnmarshalError string + +func (e UnmarshalError) String() string { return string(e) } + +// A TagPathError represents an error in the unmarshalling process +// caused by the use of field tags with conflicting paths. +type TagPathError struct { + Struct reflect.Type + Field1, Tag1 string + Field2, Tag2 string +} + +func (e *TagPathError) String() string { + return fmt.Sprintf("%s field %q with tag %q conflicts with field %q with tag %q", e.Struct, e.Field1, e.Tag1, e.Field2, e.Tag2) +} + +// The Parser's Unmarshal method is like xml.Unmarshal +// except that it can be passed a pointer to the initial start element, +// useful when a client reads some raw XML tokens itself +// but also defers to Unmarshal for some elements. +// Passing a nil start element indicates that Unmarshal should +// read the token stream to find the start element. +func (p *Parser) Unmarshal(val interface{}, start *StartElement) os.Error { + v := reflect.NewValue(val) + if v.Kind() != reflect.Ptr { + return os.NewError("non-pointer passed to Unmarshal") + } + return p.unmarshal(v.Elem(), start) +} + +// fieldName strips invalid characters from an XML name +// to create a valid Go struct name. It also converts the +// name to lower case letters. +func fieldName(original string) string { + + var i int + //remove leading underscores + for i = 0; i < len(original) && original[i] == '_'; i++ { + } + + return strings.Map( + func(x int) int { + if x == '_' || unicode.IsDigit(x) || unicode.IsLetter(x) { + return unicode.ToLower(x) + } + return -1 + }, + original[i:]) +} + +// Unmarshal a single XML element into val. +func (p *Parser) unmarshal(val reflect.Value, start *StartElement) os.Error { + // Find start element if we need it. + if start == nil { + for { + tok, err := p.Token() + if err != nil { + return err + } + if t, ok := tok.(StartElement); ok { + start = &t + break + } + } + } + + if pv := val; pv.Kind() == reflect.Ptr { + if pv.Pointer() == 0 { + zv := reflect.Zero(pv.Type().Elem()) + pv.Set(zv.Addr()) + val = zv + } else { + val = pv.Elem() + } + } + + var ( + data []byte + saveData reflect.Value + comment []byte + saveComment reflect.Value + saveXML reflect.Value + saveXMLIndex int + saveXMLData []byte + sv reflect.Value + styp reflect.Type + fieldPaths map[string]pathInfo + ) + + switch v := val; v.Kind() { + default: + return os.ErrorString("unknown type " + v.Type().String()) + + case reflect.Slice: + typ := v.Type() + if typ.Elem().Kind() == reflect.Uint8 { + // []byte + saveData = v + break + } + + // Slice of element values. + // Grow slice. 
+ n := v.Len() + if n >= v.Cap() { + ncap := 2 * n + if ncap < 4 { + ncap = 4 + } + new := reflect.MakeSlice(typ, n, ncap) + reflect.Copy(new, v) + v.Set(new) + } + v.SetLen(n + 1) + + // Recur to read element into slice. + if err := p.unmarshal(v.Index(n), start); err != nil { + v.SetLen(n) + return err + } + return nil + + case reflect.Bool, reflect.Float32, reflect.Float64, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr, reflect.String: + saveData = v + + case reflect.Struct: + if _, ok := v.Interface().(Name); ok { + v.Set(reflect.NewValue(start.Name)) + break + } + + sv = v + typ := sv.Type() + styp = typ + // Assign name. + if f, ok := typ.FieldByName("XMLName"); ok { + // Validate element name. + if f.Tag != "" { + tag := f.Tag + ns := "" + i := strings.LastIndex(tag, " ") + if i >= 0 { + ns, tag = tag[0:i], tag[i+1:] + } + if tag != start.Name.Local { + return UnmarshalError("expected element type <" + tag + "> but have <" + start.Name.Local + ">") + } + if ns != "" && ns != start.Name.Space { + e := "expected element <" + tag + "> in name space " + ns + " but have " + if start.Name.Space == "" { + e += "no name space" + } else { + e += start.Name.Space + } + return UnmarshalError(e) + } + } + + // Save + v := sv.FieldByIndex(f.Index) + if _, ok := v.Interface().(Name); !ok { + return UnmarshalError(sv.Type().String() + " field XMLName does not have type xml.Name") + } + v.Set(reflect.NewValue(start.Name)) + } + + // Assign attributes. + // Also, determine whether we need to save character data or comments. + for i, n := 0, typ.NumField(); i < n; i++ { + f := typ.Field(i) + switch f.Tag { + case "attr": + strv := sv.FieldByIndex(f.Index) + if strv.Kind() != reflect.String { + return UnmarshalError(sv.Type().String() + " field " + f.Name + " has attr tag but is not type string") + } + // Look for attribute. + val := "" + k := strings.ToLower(f.Name) + for _, a := range start.Attr { + if fieldName(a.Name.Local) == k { + val = a.Value + break + } + } + strv.SetString(val) + + case "comment": + if !saveComment.IsValid() { + saveComment = sv.FieldByIndex(f.Index) + } + + case "chardata": + if !saveData.IsValid() { + saveData = sv.FieldByIndex(f.Index) + } + + case "innerxml": + if !saveXML.IsValid() { + saveXML = sv.FieldByIndex(f.Index) + if p.saved == nil { + saveXMLIndex = 0 + p.saved = new(bytes.Buffer) + } else { + saveXMLIndex = p.savedOffset() + } + } + + default: + if strings.Contains(f.Tag, ">") { + if fieldPaths == nil { + fieldPaths = make(map[string]pathInfo) + } + path := strings.ToLower(f.Tag) + if strings.HasPrefix(f.Tag, ">") { + path = strings.ToLower(f.Name) + path + } + if strings.HasSuffix(f.Tag, ">") { + path = path[:len(path)-1] + } + err := addFieldPath(sv, fieldPaths, path, f.Index) + if err != nil { + return err + } + } + } + } + } + + // Find end element. + // Process sub-elements along the way. +Loop: + for { + var savedOffset int + if saveXML.IsValid() { + savedOffset = p.savedOffset() + } + tok, err := p.Token() + if err != nil { + return err + } + switch t := tok.(type) { + case StartElement: + // Sub-element. + // Look up by tag name. 
+ if sv.IsValid() { + k := fieldName(t.Name.Local) + + if fieldPaths != nil { + if _, found := fieldPaths[k]; found { + if err := p.unmarshalPaths(sv, fieldPaths, k, &t); err != nil { + return err + } + continue Loop + } + } + + match := func(s string) bool { + // check if the name matches ignoring case + if strings.ToLower(s) != k { + return false + } + // now check that it's public + c, _ := utf8.DecodeRuneInString(s) + return unicode.IsUpper(c) + } + + f, found := styp.FieldByNameFunc(match) + if !found { // fall back to mop-up field named "Any" + f, found = styp.FieldByName("Any") + } + if found { + if err := p.unmarshal(sv.FieldByIndex(f.Index), &t); err != nil { + return err + } + continue Loop + } + } + // Not saving sub-element but still have to skip over it. + if err := p.Skip(); err != nil { + return err + } + + case EndElement: + if saveXML.IsValid() { + saveXMLData = p.saved.Bytes()[saveXMLIndex:savedOffset] + if saveXMLIndex == 0 { + p.saved = nil + } + } + break Loop + + case CharData: + if saveData.IsValid() { + data = append(data, t...) + } + + case Comment: + if saveComment.IsValid() { + comment = append(comment, t...) + } + } + } + + var err os.Error + // Helper functions for integer and unsigned integer conversions + var itmp int64 + getInt64 := func() bool { + itmp, err = strconv.Atoi64(string(data)) + // TODO: should check sizes + return err == nil + } + var utmp uint64 + getUint64 := func() bool { + utmp, err = strconv.Atoui64(string(data)) + // TODO: check for overflow? + return err == nil + } + var ftmp float64 + getFloat64 := func() bool { + ftmp, err = strconv.Atof64(string(data)) + // TODO: check for overflow? + return err == nil + } + + // Save accumulated data and comments + switch t := saveData; t.Kind() { + case reflect.Invalid: + // Probably a comment, handled below + default: + return os.ErrorString("cannot happen: unknown type " + t.Type().String()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + if !getInt64() { + return err + } + t.SetInt(itmp) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + if !getUint64() { + return err + } + t.SetUint(utmp) + case reflect.Float32, reflect.Float64: + if !getFloat64() { + return err + } + t.SetFloat(ftmp) + case reflect.Bool: + value, err := strconv.Atob(strings.TrimSpace(string(data))) + if err != nil { + return err + } + t.SetBool(value) + case reflect.String: + t.SetString(string(data)) + case reflect.Slice: + t.Set(reflect.NewValue(data)) + } + + switch t := saveComment; t.Kind() { + case reflect.String: + t.SetString(string(comment)) + case reflect.Slice: + t.Set(reflect.NewValue(comment)) + } + + switch t := saveXML; t.Kind() { + case reflect.String: + t.SetString(string(saveXMLData)) + case reflect.Slice: + t.Set(reflect.NewValue(saveXMLData)) + } + + return nil +} + +type pathInfo struct { + fieldIdx []int + complete bool +} + +// addFieldPath takes an element path such as "a>b>c" and fills the +// paths map with all paths leading to it ("a", "a>b", and "a>b>c"). +// It is okay for paths to share a common, shorter prefix but not ok +// for one path to itself be a prefix of another. 
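+// For example, hypothetical tags "a>b" and "a>c" can coexist on two
+// fields, while "a>b" together with "a>b>c" is reported as a TagPathError.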
+func addFieldPath(sv reflect.Value, paths map[string]pathInfo, path string, fieldIdx []int) os.Error { + if info, found := paths[path]; found { + return tagError(sv, info.fieldIdx, fieldIdx) + } + paths[path] = pathInfo{fieldIdx, true} + for { + i := strings.LastIndex(path, ">") + if i < 0 { + break + } + path = path[:i] + if info, found := paths[path]; found { + if info.complete { + return tagError(sv, info.fieldIdx, fieldIdx) + } + } else { + paths[path] = pathInfo{fieldIdx, false} + } + } + return nil + +} + +func tagError(sv reflect.Value, idx1 []int, idx2 []int) os.Error { + t := sv.Type() + f1 := t.FieldByIndex(idx1) + f2 := t.FieldByIndex(idx2) + return &TagPathError{t, f1.Name, f1.Tag, f2.Name, f2.Tag} +} + +// unmarshalPaths walks down an XML structure looking for +// wanted paths, and calls unmarshal on them. +func (p *Parser) unmarshalPaths(sv reflect.Value, paths map[string]pathInfo, path string, start *StartElement) os.Error { + if info, _ := paths[path]; info.complete { + return p.unmarshal(sv.FieldByIndex(info.fieldIdx), start) + } + for { + tok, err := p.Token() + if err != nil { + return err + } + switch t := tok.(type) { + case StartElement: + k := path + ">" + fieldName(t.Name.Local) + if _, found := paths[k]; found { + if err := p.unmarshalPaths(sv, paths, k, &t); err != nil { + return err + } + continue + } + if err := p.Skip(); err != nil { + return err + } + case EndElement: + return nil + } + } + panic("unreachable") +} + +// Have already read a start element. +// Read tokens until we find the end element. +// Token is taking care of making sure the +// end element matches the start element we saw. +func (p *Parser) Skip() os.Error { + for { + tok, err := p.Token() + if err != nil { + return err + } + switch t := tok.(type) { + case StartElement: + if err := p.Skip(); err != nil { + return err + } + case EndElement: + return nil + } + } + panic("unreachable") +} diff --git a/src/cmd/gofix/testdata/reflect.scan.go.in b/src/cmd/gofix/testdata/reflect.scan.go.in new file mode 100644 index 000000000..36271a8d4 --- /dev/null +++ b/src/cmd/gofix/testdata/reflect.scan.go.in @@ -0,0 +1,1084 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package fmt + +import ( + "bytes" + "io" + "math" + "os" + "reflect" + "strconv" + "strings" + "unicode" + "utf8" +) + +// runeUnreader is the interface to something that can unread runes. +// If the object provided to Scan does not satisfy this interface, +// a local buffer will be used to back up the input, but its contents +// will be lost when Scan returns. +type runeUnreader interface { + UnreadRune() os.Error +} + +// ScanState represents the scanner state passed to custom scanners. +// Scanners may do rune-at-a-time scanning or ask the ScanState +// to discover the next space-delimited token. +type ScanState interface { + // ReadRune reads the next rune (Unicode code point) from the input. + // If invoked during Scanln, Fscanln, or Sscanln, ReadRune() will + // return EOF after returning the first '\n' or when reading beyond + // the specified width. + ReadRune() (rune int, size int, err os.Error) + // UnreadRune causes the next call to ReadRune to return the same rune. + UnreadRune() os.Error + // Token skips space in the input if skipSpace is true, then returns the + // run of Unicode code points c satisfying f(c). 
If f is nil, + // !unicode.IsSpace(c) is used; that is, the token will hold non-space + // characters. Newlines are treated as space unless the scan operation + // is Scanln, Fscanln or Sscanln, in which case a newline is treated as + // EOF. The returned slice points to shared data that may be overwritten + // by the next call to Token, a call to a Scan function using the ScanState + // as input, or when the calling Scan method returns. + Token(skipSpace bool, f func(int) bool) (token []byte, err os.Error) + // Width returns the value of the width option and whether it has been set. + // The unit is Unicode code points. + Width() (wid int, ok bool) + // Because ReadRune is implemented by the interface, Read should never be + // called by the scanning routines and a valid implementation of + // ScanState may choose always to return an error from Read. + Read(buf []byte) (n int, err os.Error) +} + +// Scanner is implemented by any value that has a Scan method, which scans +// the input for the representation of a value and stores the result in the +// receiver, which must be a pointer to be useful. The Scan method is called +// for any argument to Scan, Scanf, or Scanln that implements it. +type Scanner interface { + Scan(state ScanState, verb int) os.Error +} + +// Scan scans text read from standard input, storing successive +// space-separated values into successive arguments. Newlines count +// as space. It returns the number of items successfully scanned. +// If that is less than the number of arguments, err will report why. +func Scan(a ...interface{}) (n int, err os.Error) { + return Fscan(os.Stdin, a...) +} + +// Scanln is similar to Scan, but stops scanning at a newline and +// after the final item there must be a newline or EOF. +func Scanln(a ...interface{}) (n int, err os.Error) { + return Fscanln(os.Stdin, a...) +} + +// Scanf scans text read from standard input, storing successive +// space-separated values into successive arguments as determined by +// the format. It returns the number of items successfully scanned. +func Scanf(format string, a ...interface{}) (n int, err os.Error) { + return Fscanf(os.Stdin, format, a...) +} + +// Sscan scans the argument string, storing successive space-separated +// values into successive arguments. Newlines count as space. It +// returns the number of items successfully scanned. If that is less +// than the number of arguments, err will report why. +func Sscan(str string, a ...interface{}) (n int, err os.Error) { + return Fscan(strings.NewReader(str), a...) +} + +// Sscanln is similar to Sscan, but stops scanning at a newline and +// after the final item there must be a newline or EOF. +func Sscanln(str string, a ...interface{}) (n int, err os.Error) { + return Fscanln(strings.NewReader(str), a...) +} + +// Sscanf scans the argument string, storing successive space-separated +// values into successive arguments as determined by the format. It +// returns the number of items successfully parsed. +func Sscanf(str string, format string, a ...interface{}) (n int, err os.Error) { + return Fscanf(strings.NewReader(str), format, a...) +} + +// Fscan scans text read from r, storing successive space-separated +// values into successive arguments. Newlines count as space. It +// returns the number of items successfully scanned. If that is less +// than the number of arguments, err will report why. 
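+//
+// A minimal usage sketch (the input and variable names are illustrative):
+//
+//	var name string
+//	var age int
+//	n, err := fmt.Fscan(strings.NewReader("gopher 7"), &name, &age)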
+func Fscan(r io.Reader, a ...interface{}) (n int, err os.Error) { + s, old := newScanState(r, true, false) + n, err = s.doScan(a) + s.free(old) + return +} + +// Fscanln is similar to Fscan, but stops scanning at a newline and +// after the final item there must be a newline or EOF. +func Fscanln(r io.Reader, a ...interface{}) (n int, err os.Error) { + s, old := newScanState(r, false, true) + n, err = s.doScan(a) + s.free(old) + return +} + +// Fscanf scans text read from r, storing successive space-separated +// values into successive arguments as determined by the format. It +// returns the number of items successfully parsed. +func Fscanf(r io.Reader, format string, a ...interface{}) (n int, err os.Error) { + s, old := newScanState(r, false, false) + n, err = s.doScanf(format, a) + s.free(old) + return +} + +// scanError represents an error generated by the scanning software. +// It's used as a unique signature to identify such errors when recovering. +type scanError struct { + err os.Error +} + +const eof = -1 + +// ss is the internal implementation of ScanState. +type ss struct { + rr io.RuneReader // where to read input + buf bytes.Buffer // token accumulator + peekRune int // one-rune lookahead + prevRune int // last rune returned by ReadRune + count int // runes consumed so far. + atEOF bool // already read EOF + ssave +} + +// ssave holds the parts of ss that need to be +// saved and restored on recursive scans. +type ssave struct { + validSave bool // is or was a part of an actual ss. + nlIsEnd bool // whether newline terminates scan + nlIsSpace bool // whether newline counts as white space + fieldLimit int // max value of ss.count for this field; fieldLimit <= limit + limit int // max value of ss.count. + maxWid int // width of this field. +} + +// The Read method is only in ScanState so that ScanState +// satisfies io.Reader. It will never be called when used as +// intended, so there is no need to make it actually work. +func (s *ss) Read(buf []byte) (n int, err os.Error) { + return 0, os.ErrorString("ScanState's Read should not be called. Use ReadRune") +} + +func (s *ss) ReadRune() (rune int, size int, err os.Error) { + if s.peekRune >= 0 { + s.count++ + rune = s.peekRune + size = utf8.RuneLen(rune) + s.prevRune = rune + s.peekRune = -1 + return + } + if s.atEOF || s.nlIsEnd && s.prevRune == '\n' || s.count >= s.fieldLimit { + err = os.EOF + return + } + + rune, size, err = s.rr.ReadRune() + if err == nil { + s.count++ + s.prevRune = rune + } else if err == os.EOF { + s.atEOF = true + } + return +} + +func (s *ss) Width() (wid int, ok bool) { + if s.maxWid == hugeWid { + return 0, false + } + return s.maxWid, true +} + +// The public method returns an error; this private one panics. +// If getRune reaches EOF, the return value is EOF (-1). +func (s *ss) getRune() (rune int) { + rune, _, err := s.ReadRune() + if err != nil { + if err == os.EOF { + return eof + } + s.error(err) + } + return +} + +// mustReadRune turns os.EOF into a panic(io.ErrUnexpectedEOF). +// It is called in cases such as string scanning where an EOF is a +// syntax error. 
+func (s *ss) mustReadRune() (rune int) { + rune = s.getRune() + if rune == eof { + s.error(io.ErrUnexpectedEOF) + } + return +} + +func (s *ss) UnreadRune() os.Error { + if u, ok := s.rr.(runeUnreader); ok { + u.UnreadRune() + } else { + s.peekRune = s.prevRune + } + s.count-- + return nil +} + +func (s *ss) error(err os.Error) { + panic(scanError{err}) +} + +func (s *ss) errorString(err string) { + panic(scanError{os.ErrorString(err)}) +} + +func (s *ss) Token(skipSpace bool, f func(int) bool) (tok []byte, err os.Error) { + defer func() { + if e := recover(); e != nil { + if se, ok := e.(scanError); ok { + err = se.err + } else { + panic(e) + } + } + }() + if f == nil { + f = notSpace + } + s.buf.Reset() + tok = s.token(skipSpace, f) + return +} + +// notSpace is the default scanning function used in Token. +func notSpace(r int) bool { + return !unicode.IsSpace(r) +} + +// readRune is a structure to enable reading UTF-8 encoded code points +// from an io.Reader. It is used if the Reader given to the scanner does +// not already implement io.RuneReader. +type readRune struct { + reader io.Reader + buf [utf8.UTFMax]byte // used only inside ReadRune + pending int // number of bytes in pendBuf; only >0 for bad UTF-8 + pendBuf [utf8.UTFMax]byte // bytes left over +} + +// readByte returns the next byte from the input, which may be +// left over from a previous read if the UTF-8 was ill-formed. +func (r *readRune) readByte() (b byte, err os.Error) { + if r.pending > 0 { + b = r.pendBuf[0] + copy(r.pendBuf[0:], r.pendBuf[1:]) + r.pending-- + return + } + _, err = r.reader.Read(r.pendBuf[0:1]) + return r.pendBuf[0], err +} + +// unread saves the bytes for the next read. +func (r *readRune) unread(buf []byte) { + copy(r.pendBuf[r.pending:], buf) + r.pending += len(buf) +} + +// ReadRune returns the next UTF-8 encoded code point from the +// io.Reader inside r. +func (r *readRune) ReadRune() (rune int, size int, err os.Error) { + r.buf[0], err = r.readByte() + if err != nil { + return 0, 0, err + } + if r.buf[0] < utf8.RuneSelf { // fast check for common ASCII case + rune = int(r.buf[0]) + return + } + var n int + for n = 1; !utf8.FullRune(r.buf[0:n]); n++ { + r.buf[n], err = r.readByte() + if err != nil { + if err == os.EOF { + err = nil + break + } + return + } + } + rune, size = utf8.DecodeRune(r.buf[0:n]) + if size < n { // an error + r.unread(r.buf[size:n]) + } + return +} + + +var ssFree = newCache(func() interface{} { return new(ss) }) + +// Allocate a new ss struct or grab a cached one. +func newScanState(r io.Reader, nlIsSpace, nlIsEnd bool) (s *ss, old ssave) { + // If the reader is a *ss, then we've got a recursive + // call to Scan, so re-use the scan state. + s, ok := r.(*ss) + if ok { + old = s.ssave + s.limit = s.fieldLimit + s.nlIsEnd = nlIsEnd || s.nlIsEnd + s.nlIsSpace = nlIsSpace + return + } + + s = ssFree.get().(*ss) + if rr, ok := r.(io.RuneReader); ok { + s.rr = rr + } else { + s.rr = &readRune{reader: r} + } + s.nlIsSpace = nlIsSpace + s.nlIsEnd = nlIsEnd + s.prevRune = -1 + s.peekRune = -1 + s.atEOF = false + s.limit = hugeWid + s.fieldLimit = hugeWid + s.maxWid = hugeWid + s.validSave = true + return +} + +// Save used ss structs in ssFree; avoid an allocation per invocation. +func (s *ss) free(old ssave) { + // If it was used recursively, just restore the old state. + if old.validSave { + s.ssave = old + return + } + // Don't hold on to ss structs with large buffers. 
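+	// Returning here, without calling ssFree.put, drops the struct so
+	// the garbage collector can reclaim its oversized buffer.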
+ if cap(s.buf.Bytes()) > 1024 { + return + } + s.buf.Reset() + s.rr = nil + ssFree.put(s) +} + +// skipSpace skips spaces and maybe newlines. +func (s *ss) skipSpace(stopAtNewline bool) { + for { + rune := s.getRune() + if rune == eof { + return + } + if rune == '\n' { + if stopAtNewline { + break + } + if s.nlIsSpace { + continue + } + s.errorString("unexpected newline") + return + } + if !unicode.IsSpace(rune) { + s.UnreadRune() + break + } + } +} + + +// token returns the next space-delimited string from the input. It +// skips white space. For Scanln, it stops at newlines. For Scan, +// newlines are treated as spaces. +func (s *ss) token(skipSpace bool, f func(int) bool) []byte { + if skipSpace { + s.skipSpace(false) + } + // read until white space or newline + for { + rune := s.getRune() + if rune == eof { + break + } + if !f(rune) { + s.UnreadRune() + break + } + s.buf.WriteRune(rune) + } + return s.buf.Bytes() +} + +// typeError indicates that the type of the operand did not match the format +func (s *ss) typeError(field interface{}, expected string) { + s.errorString("expected field of type pointer to " + expected + "; found " + reflect.Typeof(field).String()) +} + +var complexError = os.ErrorString("syntax error scanning complex number") +var boolError = os.ErrorString("syntax error scanning boolean") + +// consume reads the next rune in the input and reports whether it is in the ok string. +// If accept is true, it puts the character into the input token. +func (s *ss) consume(ok string, accept bool) bool { + rune := s.getRune() + if rune == eof { + return false + } + if strings.IndexRune(ok, rune) >= 0 { + if accept { + s.buf.WriteRune(rune) + } + return true + } + if rune != eof && accept { + s.UnreadRune() + } + return false +} + +// peek reports whether the next character is in the ok string, without consuming it. +func (s *ss) peek(ok string) bool { + rune := s.getRune() + if rune != eof { + s.UnreadRune() + } + return strings.IndexRune(ok, rune) >= 0 +} + +// accept checks the next rune in the input. If it's a byte (sic) in the string, it puts it in the +// buffer and returns true. Otherwise it return false. +func (s *ss) accept(ok string) bool { + return s.consume(ok, true) +} + +// okVerb verifies that the verb is present in the list, setting s.err appropriately if not. +func (s *ss) okVerb(verb int, okVerbs, typ string) bool { + for _, v := range okVerbs { + if v == verb { + return true + } + } + s.errorString("bad verb %" + string(verb) + " for " + typ) + return false +} + +// scanBool returns the value of the boolean represented by the next token. +func (s *ss) scanBool(verb int) bool { + if !s.okVerb(verb, "tv", "boolean") { + return false + } + // Syntax-checking a boolean is annoying. We're not fastidious about case. + switch s.mustReadRune() { + case '0': + return false + case '1': + return true + case 't', 'T': + if s.accept("rR") && (!s.accept("uU") || !s.accept("eE")) { + s.error(boolError) + } + return true + case 'f', 'F': + if s.accept("aL") && (!s.accept("lL") || !s.accept("sS") || !s.accept("eE")) { + s.error(boolError) + } + return false + } + return false +} + +// Numerical elements +const ( + binaryDigits = "01" + octalDigits = "01234567" + decimalDigits = "0123456789" + hexadecimalDigits = "0123456789aAbBcCdDeEfF" + sign = "+-" + period = "." + exponent = "eEp" +) + +// getBase returns the numeric base represented by the verb and its digit string. 
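+// For example, the verb 'x' selects base 16 with the hexadecimal digit
+// set, while the catch-all verb 'v' keeps base 10 and decimal digits.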
+func (s *ss) getBase(verb int) (base int, digits string) { + s.okVerb(verb, "bdoUxXv", "integer") // sets s.err + base = 10 + digits = decimalDigits + switch verb { + case 'b': + base = 2 + digits = binaryDigits + case 'o': + base = 8 + digits = octalDigits + case 'x', 'X', 'U': + base = 16 + digits = hexadecimalDigits + } + return +} + +// scanNumber returns the numerical string with specified digits starting here. +func (s *ss) scanNumber(digits string, haveDigits bool) string { + if !haveDigits && !s.accept(digits) { + s.errorString("expected integer") + } + for s.accept(digits) { + } + return s.buf.String() +} + +// scanRune returns the next rune value in the input. +func (s *ss) scanRune(bitSize int) int64 { + rune := int64(s.mustReadRune()) + n := uint(bitSize) + x := (rune << (64 - n)) >> (64 - n) + if x != rune { + s.errorString("overflow on character value " + string(rune)) + } + return rune +} + +// scanBasePrefix reports whether the integer begins with a 0 or 0x, +// and returns the base, digit string, and whether a zero was found. +// It is called only if the verb is %v. +func (s *ss) scanBasePrefix() (base int, digits string, found bool) { + if !s.peek("0") { + return 10, decimalDigits, false + } + s.accept("0") + found = true // We've put a digit into the token buffer. + // Special cases for '0' && '0x' + base, digits = 8, octalDigits + if s.peek("xX") { + s.consume("xX", false) + base, digits = 16, hexadecimalDigits + } + return +} + +// scanInt returns the value of the integer represented by the next +// token, checking for overflow. Any error is stored in s.err. +func (s *ss) scanInt(verb int, bitSize int) int64 { + if verb == 'c' { + return s.scanRune(bitSize) + } + s.skipSpace(false) + base, digits := s.getBase(verb) + haveDigits := false + if verb == 'U' { + if !s.consume("U", false) || !s.consume("+", false) { + s.errorString("bad unicode format ") + } + } else { + s.accept(sign) // If there's a sign, it will be left in the token buffer. + if verb == 'v' { + base, digits, haveDigits = s.scanBasePrefix() + } + } + tok := s.scanNumber(digits, haveDigits) + i, err := strconv.Btoi64(tok, base) + if err != nil { + s.error(err) + } + n := uint(bitSize) + x := (i << (64 - n)) >> (64 - n) + if x != i { + s.errorString("integer overflow on token " + tok) + } + return i +} + +// scanUint returns the value of the unsigned integer represented +// by the next token, checking for overflow. Any error is stored in s.err. +func (s *ss) scanUint(verb int, bitSize int) uint64 { + if verb == 'c' { + return uint64(s.scanRune(bitSize)) + } + s.skipSpace(false) + base, digits := s.getBase(verb) + haveDigits := false + if verb == 'U' { + if !s.consume("U", false) || !s.consume("+", false) { + s.errorString("bad unicode format ") + } + } else if verb == 'v' { + base, digits, haveDigits = s.scanBasePrefix() + } + tok := s.scanNumber(digits, haveDigits) + i, err := strconv.Btoui64(tok, base) + if err != nil { + s.error(err) + } + n := uint(bitSize) + x := (i << (64 - n)) >> (64 - n) + if x != i { + s.errorString("unsigned integer overflow on token " + tok) + } + return i +} + +// floatToken returns the floating-point number starting here, no longer than swid +// if the width is specified. It's not rigorous about syntax because it doesn't check that +// we have at least some digits, but Atof will do that. +func (s *ss) floatToken() string { + s.buf.Reset() + // NaN? + if s.accept("nN") && s.accept("aA") && s.accept("nN") { + return s.buf.String() + } + // leading sign? + s.accept(sign) + // Inf? 
+ if s.accept("iI") && s.accept("nN") && s.accept("fF") { + return s.buf.String() + } + // digits? + for s.accept(decimalDigits) { + } + // decimal point? + if s.accept(period) { + // fraction? + for s.accept(decimalDigits) { + } + } + // exponent? + if s.accept(exponent) { + // leading sign? + s.accept(sign) + // digits? + for s.accept(decimalDigits) { + } + } + return s.buf.String() +} + +// complexTokens returns the real and imaginary parts of the complex number starting here. +// The number might be parenthesized and has the format (N+Ni) where N is a floating-point +// number and there are no spaces within. +func (s *ss) complexTokens() (real, imag string) { + // TODO: accept N and Ni independently? + parens := s.accept("(") + real = s.floatToken() + s.buf.Reset() + // Must now have a sign. + if !s.accept("+-") { + s.error(complexError) + } + // Sign is now in buffer + imagSign := s.buf.String() + imag = s.floatToken() + if !s.accept("i") { + s.error(complexError) + } + if parens && !s.accept(")") { + s.error(complexError) + } + return real, imagSign + imag +} + +// convertFloat converts the string to a float64value. +func (s *ss) convertFloat(str string, n int) float64 { + if p := strings.Index(str, "p"); p >= 0 { + // Atof doesn't handle power-of-2 exponents, + // but they're easy to evaluate. + f, err := strconv.AtofN(str[:p], n) + if err != nil { + // Put full string into error. + if e, ok := err.(*strconv.NumError); ok { + e.Num = str + } + s.error(err) + } + n, err := strconv.Atoi(str[p+1:]) + if err != nil { + // Put full string into error. + if e, ok := err.(*strconv.NumError); ok { + e.Num = str + } + s.error(err) + } + return math.Ldexp(f, n) + } + f, err := strconv.AtofN(str, n) + if err != nil { + s.error(err) + } + return f +} + +// convertComplex converts the next token to a complex128 value. +// The atof argument is a type-specific reader for the underlying type. +// If we're reading complex64, atof will parse float32s and convert them +// to float64's to avoid reproducing this code for each complex type. +func (s *ss) scanComplex(verb int, n int) complex128 { + if !s.okVerb(verb, floatVerbs, "complex") { + return 0 + } + s.skipSpace(false) + sreal, simag := s.complexTokens() + real := s.convertFloat(sreal, n/2) + imag := s.convertFloat(simag, n/2) + return complex(real, imag) +} + +// convertString returns the string represented by the next input characters. +// The format of the input is determined by the verb. +func (s *ss) convertString(verb int) (str string) { + if !s.okVerb(verb, "svqx", "string") { + return "" + } + s.skipSpace(false) + switch verb { + case 'q': + str = s.quotedString() + case 'x': + str = s.hexString() + default: + str = string(s.token(true, notSpace)) // %s and %v just return the next word + } + // Empty strings other than with %q are not OK. + if len(str) == 0 && verb != 'q' && s.maxWid > 0 { + s.errorString("Scan: no data for string") + } + return +} + +// quotedString returns the double- or back-quoted string represented by the next input characters. +func (s *ss) quotedString() string { + quote := s.mustReadRune() + switch quote { + case '`': + // Back-quoted: Anything goes until EOF or back quote. + for { + rune := s.mustReadRune() + if rune == quote { + break + } + s.buf.WriteRune(rune) + } + return s.buf.String() + case '"': + // Double-quoted: Include the quotes and let strconv.Unquote do the backslash escapes. 
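+		// Copy the opening quote and every following rune up to the
+		// closing quote, keeping backslash escapes intact so that
+		// strconv.Unquote can interpret them.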
+ s.buf.WriteRune(quote) + for { + rune := s.mustReadRune() + s.buf.WriteRune(rune) + if rune == '\\' { + // In a legal backslash escape, no matter how long, only the character + // immediately after the escape can itself be a backslash or quote. + // Thus we only need to protect the first character after the backslash. + rune := s.mustReadRune() + s.buf.WriteRune(rune) + } else if rune == '"' { + break + } + } + result, err := strconv.Unquote(s.buf.String()) + if err != nil { + s.error(err) + } + return result + default: + s.errorString("expected quoted string") + } + return "" +} + +// hexDigit returns the value of the hexadecimal digit +func (s *ss) hexDigit(digit int) int { + switch digit { + case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9': + return digit - '0' + case 'a', 'b', 'c', 'd', 'e', 'f': + return 10 + digit - 'a' + case 'A', 'B', 'C', 'D', 'E', 'F': + return 10 + digit - 'A' + } + s.errorString("Scan: illegal hex digit") + return 0 +} + +// hexByte returns the next hex-encoded (two-character) byte from the input. +// There must be either two hexadecimal digits or a space character in the input. +func (s *ss) hexByte() (b byte, ok bool) { + rune1 := s.getRune() + if rune1 == eof { + return + } + if unicode.IsSpace(rune1) { + s.UnreadRune() + return + } + rune2 := s.mustReadRune() + return byte(s.hexDigit(rune1)<<4 | s.hexDigit(rune2)), true +} + +// hexString returns the space-delimited hexpair-encoded string. +func (s *ss) hexString() string { + for { + b, ok := s.hexByte() + if !ok { + break + } + s.buf.WriteByte(b) + } + if s.buf.Len() == 0 { + s.errorString("Scan: no hex data for %x string") + return "" + } + return s.buf.String() +} + +const floatVerbs = "beEfFgGv" + +const hugeWid = 1 << 30 + +// scanOne scans a single value, deriving the scanner from the type of the argument. +func (s *ss) scanOne(verb int, field interface{}) { + s.buf.Reset() + var err os.Error + // If the parameter has its own Scan method, use that. + if v, ok := field.(Scanner); ok { + err = v.Scan(s, verb) + if err != nil { + if err == os.EOF { + err = io.ErrUnexpectedEOF + } + s.error(err) + } + return + } + switch v := field.(type) { + case *bool: + *v = s.scanBool(verb) + case *complex64: + *v = complex64(s.scanComplex(verb, 64)) + case *complex128: + *v = s.scanComplex(verb, 128) + case *int: + *v = int(s.scanInt(verb, intBits)) + case *int8: + *v = int8(s.scanInt(verb, 8)) + case *int16: + *v = int16(s.scanInt(verb, 16)) + case *int32: + *v = int32(s.scanInt(verb, 32)) + case *int64: + *v = s.scanInt(verb, 64) + case *uint: + *v = uint(s.scanUint(verb, intBits)) + case *uint8: + *v = uint8(s.scanUint(verb, 8)) + case *uint16: + *v = uint16(s.scanUint(verb, 16)) + case *uint32: + *v = uint32(s.scanUint(verb, 32)) + case *uint64: + *v = s.scanUint(verb, 64) + case *uintptr: + *v = uintptr(s.scanUint(verb, uintptrBits)) + // Floats are tricky because you want to scan in the precision of the result, not + // scan in high precision and convert, in order to preserve the correct error condition. + case *float32: + if s.okVerb(verb, floatVerbs, "float32") { + s.skipSpace(false) + *v = float32(s.convertFloat(s.floatToken(), 32)) + } + case *float64: + if s.okVerb(verb, floatVerbs, "float64") { + s.skipSpace(false) + *v = s.convertFloat(s.floatToken(), 64) + } + case *string: + *v = s.convertString(verb) + case *[]byte: + // We scan to string and convert so we get a copy of the data. + // If we scanned to bytes, the slice would point at the buffer. 
+ *v = []byte(s.convertString(verb)) + default: + val := reflect.NewValue(v) + ptr, ok := val.(*reflect.PtrValue) + if !ok { + s.errorString("Scan: type not a pointer: " + val.Type().String()) + return + } + switch v := ptr.Elem().(type) { + case *reflect.BoolValue: + v.Set(s.scanBool(verb)) + case *reflect.IntValue: + v.Set(s.scanInt(verb, v.Type().Bits())) + case *reflect.UintValue: + v.Set(s.scanUint(verb, v.Type().Bits())) + case *reflect.StringValue: + v.Set(s.convertString(verb)) + case *reflect.SliceValue: + // For now, can only handle (renamed) []byte. + typ := v.Type().(*reflect.SliceType) + if typ.Elem().Kind() != reflect.Uint8 { + goto CantHandle + } + str := s.convertString(verb) + v.Set(reflect.MakeSlice(typ, len(str), len(str))) + for i := 0; i < len(str); i++ { + v.Elem(i).(*reflect.UintValue).Set(uint64(str[i])) + } + case *reflect.FloatValue: + s.skipSpace(false) + v.Set(s.convertFloat(s.floatToken(), v.Type().Bits())) + case *reflect.ComplexValue: + v.Set(s.scanComplex(verb, v.Type().Bits())) + default: + CantHandle: + s.errorString("Scan: can't handle type: " + val.Type().String()) + } + } +} + +// errorHandler turns local panics into error returns. EOFs are benign. +func errorHandler(errp *os.Error) { + if e := recover(); e != nil { + if se, ok := e.(scanError); ok { // catch local error + if se.err != os.EOF { + *errp = se.err + } + } else { + panic(e) + } + } +} + +// doScan does the real work for scanning without a format string. +func (s *ss) doScan(a []interface{}) (numProcessed int, err os.Error) { + defer errorHandler(&err) + for _, field := range a { + s.scanOne('v', field) + numProcessed++ + } + // Check for newline if required. + if !s.nlIsSpace { + for { + rune := s.getRune() + if rune == '\n' || rune == eof { + break + } + if !unicode.IsSpace(rune) { + s.errorString("Scan: expected newline") + break + } + } + } + return +} + +// advance determines whether the next characters in the input match +// those of the format. It returns the number of bytes (sic) consumed +// in the format. Newlines included, all runs of space characters in +// either input or format behave as a single space. This routine also +// handles the %% case. If the return value is zero, either format +// starts with a % (with no following %) or the input is empty. +// If it is negative, the input did not match the string. +func (s *ss) advance(format string) (i int) { + for i < len(format) { + fmtc, w := utf8.DecodeRuneInString(format[i:]) + if fmtc == '%' { + // %% acts like a real percent + nextc, _ := utf8.DecodeRuneInString(format[i+w:]) // will not match % if string is empty + if nextc != '%' { + return + } + i += w // skip the first % + } + sawSpace := false + for unicode.IsSpace(fmtc) && i < len(format) { + sawSpace = true + i += w + fmtc, w = utf8.DecodeRuneInString(format[i:]) + } + if sawSpace { + // There was space in the format, so there should be space (EOF) + // in the input. + inputc := s.getRune() + if inputc == eof { + return + } + if !unicode.IsSpace(inputc) { + // Space in format but not in input: error + s.errorString("expected space in input to match format") + } + s.skipSpace(true) + continue + } + inputc := s.mustReadRune() + if fmtc != inputc { + s.UnreadRune() + return -1 + } + i += w + } + return +} + +// doScanf does the real work when scanning with a format string. +// At the moment, it handles only pointers to basic types. 
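+// For example, a format of "%d %s" consumes one integer operand and one
+// space-delimited string operand from the input (an illustrative case).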
+func (s *ss) doScanf(format string, a []interface{}) (numProcessed int, err os.Error) { + defer errorHandler(&err) + end := len(format) - 1 + // We process one item per non-trivial format + for i := 0; i <= end; { + w := s.advance(format[i:]) + if w > 0 { + i += w + continue + } + // Either we failed to advance, we have a percent character, or we ran out of input. + if format[i] != '%' { + // Can't advance format. Why not? + if w < 0 { + s.errorString("input does not match format") + } + // Otherwise at EOF; "too many operands" error handled below + break + } + i++ // % is one byte + + // do we have 20 (width)? + var widPresent bool + s.maxWid, widPresent, i = parsenum(format, i, end) + if !widPresent { + s.maxWid = hugeWid + } + s.fieldLimit = s.limit + if f := s.count + s.maxWid; f < s.fieldLimit { + s.fieldLimit = f + } + + c, w := utf8.DecodeRuneInString(format[i:]) + i += w + + if numProcessed >= len(a) { // out of operands + s.errorString("too few operands for format %" + format[i-w:]) + break + } + field := a[numProcessed] + + s.scanOne(c, field) + numProcessed++ + s.fieldLimit = s.limit + } + if numProcessed < len(a) { + s.errorString("too many operands") + } + return +} diff --git a/src/cmd/gofix/testdata/reflect.scan.go.out b/src/cmd/gofix/testdata/reflect.scan.go.out new file mode 100644 index 000000000..b1b3975e2 --- /dev/null +++ b/src/cmd/gofix/testdata/reflect.scan.go.out @@ -0,0 +1,1084 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package fmt + +import ( + "bytes" + "io" + "math" + "os" + "reflect" + "strconv" + "strings" + "unicode" + "utf8" +) + +// runeUnreader is the interface to something that can unread runes. +// If the object provided to Scan does not satisfy this interface, +// a local buffer will be used to back up the input, but its contents +// will be lost when Scan returns. +type runeUnreader interface { + UnreadRune() os.Error +} + +// ScanState represents the scanner state passed to custom scanners. +// Scanners may do rune-at-a-time scanning or ask the ScanState +// to discover the next space-delimited token. +type ScanState interface { + // ReadRune reads the next rune (Unicode code point) from the input. + // If invoked during Scanln, Fscanln, or Sscanln, ReadRune() will + // return EOF after returning the first '\n' or when reading beyond + // the specified width. + ReadRune() (rune int, size int, err os.Error) + // UnreadRune causes the next call to ReadRune to return the same rune. + UnreadRune() os.Error + // Token skips space in the input if skipSpace is true, then returns the + // run of Unicode code points c satisfying f(c). If f is nil, + // !unicode.IsSpace(c) is used; that is, the token will hold non-space + // characters. Newlines are treated as space unless the scan operation + // is Scanln, Fscanln or Sscanln, in which case a newline is treated as + // EOF. The returned slice points to shared data that may be overwritten + // by the next call to Token, a call to a Scan function using the ScanState + // as input, or when the calling Scan method returns. + Token(skipSpace bool, f func(int) bool) (token []byte, err os.Error) + // Width returns the value of the width option and whether it has been set. + // The unit is Unicode code points. 
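+	// For instance, a width in the format, as in "%5s", is reported
+	// here as (5, true) while that operand is being scanned.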
+ Width() (wid int, ok bool) + // Because ReadRune is implemented by the interface, Read should never be + // called by the scanning routines and a valid implementation of + // ScanState may choose always to return an error from Read. + Read(buf []byte) (n int, err os.Error) +} + +// Scanner is implemented by any value that has a Scan method, which scans +// the input for the representation of a value and stores the result in the +// receiver, which must be a pointer to be useful. The Scan method is called +// for any argument to Scan, Scanf, or Scanln that implements it. +type Scanner interface { + Scan(state ScanState, verb int) os.Error +} + +// Scan scans text read from standard input, storing successive +// space-separated values into successive arguments. Newlines count +// as space. It returns the number of items successfully scanned. +// If that is less than the number of arguments, err will report why. +func Scan(a ...interface{}) (n int, err os.Error) { + return Fscan(os.Stdin, a...) +} + +// Scanln is similar to Scan, but stops scanning at a newline and +// after the final item there must be a newline or EOF. +func Scanln(a ...interface{}) (n int, err os.Error) { + return Fscanln(os.Stdin, a...) +} + +// Scanf scans text read from standard input, storing successive +// space-separated values into successive arguments as determined by +// the format. It returns the number of items successfully scanned. +func Scanf(format string, a ...interface{}) (n int, err os.Error) { + return Fscanf(os.Stdin, format, a...) +} + +// Sscan scans the argument string, storing successive space-separated +// values into successive arguments. Newlines count as space. It +// returns the number of items successfully scanned. If that is less +// than the number of arguments, err will report why. +func Sscan(str string, a ...interface{}) (n int, err os.Error) { + return Fscan(strings.NewReader(str), a...) +} + +// Sscanln is similar to Sscan, but stops scanning at a newline and +// after the final item there must be a newline or EOF. +func Sscanln(str string, a ...interface{}) (n int, err os.Error) { + return Fscanln(strings.NewReader(str), a...) +} + +// Sscanf scans the argument string, storing successive space-separated +// values into successive arguments as determined by the format. It +// returns the number of items successfully parsed. +func Sscanf(str string, format string, a ...interface{}) (n int, err os.Error) { + return Fscanf(strings.NewReader(str), format, a...) +} + +// Fscan scans text read from r, storing successive space-separated +// values into successive arguments. Newlines count as space. It +// returns the number of items successfully scanned. If that is less +// than the number of arguments, err will report why. +func Fscan(r io.Reader, a ...interface{}) (n int, err os.Error) { + s, old := newScanState(r, true, false) + n, err = s.doScan(a) + s.free(old) + return +} + +// Fscanln is similar to Fscan, but stops scanning at a newline and +// after the final item there must be a newline or EOF. +func Fscanln(r io.Reader, a ...interface{}) (n int, err os.Error) { + s, old := newScanState(r, false, true) + n, err = s.doScan(a) + s.free(old) + return +} + +// Fscanf scans text read from r, storing successive space-separated +// values into successive arguments as determined by the format. It +// returns the number of items successfully parsed. 
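+//
+// A minimal usage sketch (the reader r and variable names are illustrative):
+//
+//	var day int
+//	var month string
+//	n, err := fmt.Fscanf(r, "%d %s", &day, &month)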
+func Fscanf(r io.Reader, format string, a ...interface{}) (n int, err os.Error) { + s, old := newScanState(r, false, false) + n, err = s.doScanf(format, a) + s.free(old) + return +} + +// scanError represents an error generated by the scanning software. +// It's used as a unique signature to identify such errors when recovering. +type scanError struct { + err os.Error +} + +const eof = -1 + +// ss is the internal implementation of ScanState. +type ss struct { + rr io.RuneReader // where to read input + buf bytes.Buffer // token accumulator + peekRune int // one-rune lookahead + prevRune int // last rune returned by ReadRune + count int // runes consumed so far. + atEOF bool // already read EOF + ssave +} + +// ssave holds the parts of ss that need to be +// saved and restored on recursive scans. +type ssave struct { + validSave bool // is or was a part of an actual ss. + nlIsEnd bool // whether newline terminates scan + nlIsSpace bool // whether newline counts as white space + fieldLimit int // max value of ss.count for this field; fieldLimit <= limit + limit int // max value of ss.count. + maxWid int // width of this field. +} + +// The Read method is only in ScanState so that ScanState +// satisfies io.Reader. It will never be called when used as +// intended, so there is no need to make it actually work. +func (s *ss) Read(buf []byte) (n int, err os.Error) { + return 0, os.ErrorString("ScanState's Read should not be called. Use ReadRune") +} + +func (s *ss) ReadRune() (rune int, size int, err os.Error) { + if s.peekRune >= 0 { + s.count++ + rune = s.peekRune + size = utf8.RuneLen(rune) + s.prevRune = rune + s.peekRune = -1 + return + } + if s.atEOF || s.nlIsEnd && s.prevRune == '\n' || s.count >= s.fieldLimit { + err = os.EOF + return + } + + rune, size, err = s.rr.ReadRune() + if err == nil { + s.count++ + s.prevRune = rune + } else if err == os.EOF { + s.atEOF = true + } + return +} + +func (s *ss) Width() (wid int, ok bool) { + if s.maxWid == hugeWid { + return 0, false + } + return s.maxWid, true +} + +// The public method returns an error; this private one panics. +// If getRune reaches EOF, the return value is EOF (-1). +func (s *ss) getRune() (rune int) { + rune, _, err := s.ReadRune() + if err != nil { + if err == os.EOF { + return eof + } + s.error(err) + } + return +} + +// mustReadRune turns os.EOF into a panic(io.ErrUnexpectedEOF). +// It is called in cases such as string scanning where an EOF is a +// syntax error. +func (s *ss) mustReadRune() (rune int) { + rune = s.getRune() + if rune == eof { + s.error(io.ErrUnexpectedEOF) + } + return +} + +func (s *ss) UnreadRune() os.Error { + if u, ok := s.rr.(runeUnreader); ok { + u.UnreadRune() + } else { + s.peekRune = s.prevRune + } + s.count-- + return nil +} + +func (s *ss) error(err os.Error) { + panic(scanError{err}) +} + +func (s *ss) errorString(err string) { + panic(scanError{os.ErrorString(err)}) +} + +func (s *ss) Token(skipSpace bool, f func(int) bool) (tok []byte, err os.Error) { + defer func() { + if e := recover(); e != nil { + if se, ok := e.(scanError); ok { + err = se.err + } else { + panic(e) + } + } + }() + if f == nil { + f = notSpace + } + s.buf.Reset() + tok = s.token(skipSpace, f) + return +} + +// notSpace is the default scanning function used in Token. +func notSpace(r int) bool { + return !unicode.IsSpace(r) +} + +// readRune is a structure to enable reading UTF-8 encoded code points +// from an io.Reader. It is used if the Reader given to the scanner does +// not already implement io.RuneReader. 
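A sketch, not from this patch, of how a user-defined type plugs into this machinery through the Scanner interface declared above; it uses only ScanState.Token as specified in this file, and the type name is illustrative.

// assumes: import ("fmt"; "os")
// word collects a single non-space token when scanned.
type word struct{ s string }

func (w *word) Scan(state fmt.ScanState, verb int) os.Error {
	// A nil class function means "non-space runes", per the Token contract above.
	tok, err := state.Token(true, nil)
	if err != nil {
		return err
	}
	w.s = string(tok)
	return nil
}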
+type readRune struct { + reader io.Reader + buf [utf8.UTFMax]byte // used only inside ReadRune + pending int // number of bytes in pendBuf; only >0 for bad UTF-8 + pendBuf [utf8.UTFMax]byte // bytes left over +} + +// readByte returns the next byte from the input, which may be +// left over from a previous read if the UTF-8 was ill-formed. +func (r *readRune) readByte() (b byte, err os.Error) { + if r.pending > 0 { + b = r.pendBuf[0] + copy(r.pendBuf[0:], r.pendBuf[1:]) + r.pending-- + return + } + _, err = r.reader.Read(r.pendBuf[0:1]) + return r.pendBuf[0], err +} + +// unread saves the bytes for the next read. +func (r *readRune) unread(buf []byte) { + copy(r.pendBuf[r.pending:], buf) + r.pending += len(buf) +} + +// ReadRune returns the next UTF-8 encoded code point from the +// io.Reader inside r. +func (r *readRune) ReadRune() (rune int, size int, err os.Error) { + r.buf[0], err = r.readByte() + if err != nil { + return 0, 0, err + } + if r.buf[0] < utf8.RuneSelf { // fast check for common ASCII case + rune = int(r.buf[0]) + return + } + var n int + for n = 1; !utf8.FullRune(r.buf[0:n]); n++ { + r.buf[n], err = r.readByte() + if err != nil { + if err == os.EOF { + err = nil + break + } + return + } + } + rune, size = utf8.DecodeRune(r.buf[0:n]) + if size < n { // an error + r.unread(r.buf[size:n]) + } + return +} + + +var ssFree = newCache(func() interface{} { return new(ss) }) + +// Allocate a new ss struct or grab a cached one. +func newScanState(r io.Reader, nlIsSpace, nlIsEnd bool) (s *ss, old ssave) { + // If the reader is a *ss, then we've got a recursive + // call to Scan, so re-use the scan state. + s, ok := r.(*ss) + if ok { + old = s.ssave + s.limit = s.fieldLimit + s.nlIsEnd = nlIsEnd || s.nlIsEnd + s.nlIsSpace = nlIsSpace + return + } + + s = ssFree.get().(*ss) + if rr, ok := r.(io.RuneReader); ok { + s.rr = rr + } else { + s.rr = &readRune{reader: r} + } + s.nlIsSpace = nlIsSpace + s.nlIsEnd = nlIsEnd + s.prevRune = -1 + s.peekRune = -1 + s.atEOF = false + s.limit = hugeWid + s.fieldLimit = hugeWid + s.maxWid = hugeWid + s.validSave = true + return +} + +// Save used ss structs in ssFree; avoid an allocation per invocation. +func (s *ss) free(old ssave) { + // If it was used recursively, just restore the old state. + if old.validSave { + s.ssave = old + return + } + // Don't hold on to ss structs with large buffers. + if cap(s.buf.Bytes()) > 1024 { + return + } + s.buf.Reset() + s.rr = nil + ssFree.put(s) +} + +// skipSpace skips spaces and maybe newlines. +func (s *ss) skipSpace(stopAtNewline bool) { + for { + rune := s.getRune() + if rune == eof { + return + } + if rune == '\n' { + if stopAtNewline { + break + } + if s.nlIsSpace { + continue + } + s.errorString("unexpected newline") + return + } + if !unicode.IsSpace(rune) { + s.UnreadRune() + break + } + } +} + + +// token returns the next space-delimited string from the input. It +// skips white space. For Scanln, it stops at newlines. For Scan, +// newlines are treated as spaces. 
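ssFree above uses a constructor, newCache, that is defined elsewhere in the package and does not appear in this hunk. Below is a minimal sketch of the free-list idea it stands for; the names and buffer size are assumptions, not the real implementation.

type cacheSketch struct {
	saved chan interface{}
	fresh func() interface{}
}

func (c *cacheSketch) get() interface{} {
	select {
	case v := <-c.saved:
		return v // reuse a previously freed value
	default:
		return c.fresh() // cache empty: allocate a new one
	}
}

func (c *cacheSketch) put(v interface{}) {
	select {
	case c.saved <- v: // keep it for the next get
	default: // cache full: let it be garbage collected
	}
}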
+func (s *ss) token(skipSpace bool, f func(int) bool) []byte { + if skipSpace { + s.skipSpace(false) + } + // read until white space or newline + for { + rune := s.getRune() + if rune == eof { + break + } + if !f(rune) { + s.UnreadRune() + break + } + s.buf.WriteRune(rune) + } + return s.buf.Bytes() +} + +// typeError indicates that the type of the operand did not match the format +func (s *ss) typeError(field interface{}, expected string) { + s.errorString("expected field of type pointer to " + expected + "; found " + reflect.Typeof(field).String()) +} + +var complexError = os.ErrorString("syntax error scanning complex number") +var boolError = os.ErrorString("syntax error scanning boolean") + +// consume reads the next rune in the input and reports whether it is in the ok string. +// If accept is true, it puts the character into the input token. +func (s *ss) consume(ok string, accept bool) bool { + rune := s.getRune() + if rune == eof { + return false + } + if strings.IndexRune(ok, rune) >= 0 { + if accept { + s.buf.WriteRune(rune) + } + return true + } + if rune != eof && accept { + s.UnreadRune() + } + return false +} + +// peek reports whether the next character is in the ok string, without consuming it. +func (s *ss) peek(ok string) bool { + rune := s.getRune() + if rune != eof { + s.UnreadRune() + } + return strings.IndexRune(ok, rune) >= 0 +} + +// accept checks the next rune in the input. If it's a byte (sic) in the string, it puts it in the +// buffer and returns true. Otherwise it return false. +func (s *ss) accept(ok string) bool { + return s.consume(ok, true) +} + +// okVerb verifies that the verb is present in the list, setting s.err appropriately if not. +func (s *ss) okVerb(verb int, okVerbs, typ string) bool { + for _, v := range okVerbs { + if v == verb { + return true + } + } + s.errorString("bad verb %" + string(verb) + " for " + typ) + return false +} + +// scanBool returns the value of the boolean represented by the next token. +func (s *ss) scanBool(verb int) bool { + if !s.okVerb(verb, "tv", "boolean") { + return false + } + // Syntax-checking a boolean is annoying. We're not fastidious about case. + switch s.mustReadRune() { + case '0': + return false + case '1': + return true + case 't', 'T': + if s.accept("rR") && (!s.accept("uU") || !s.accept("eE")) { + s.error(boolError) + } + return true + case 'f', 'F': + if s.accept("aL") && (!s.accept("lL") || !s.accept("sS") || !s.accept("eE")) { + s.error(boolError) + } + return false + } + return false +} + +// Numerical elements +const ( + binaryDigits = "01" + octalDigits = "01234567" + decimalDigits = "0123456789" + hexadecimalDigits = "0123456789aAbBcCdDeEfF" + sign = "+-" + period = "." + exponent = "eEp" +) + +// getBase returns the numeric base represented by the verb and its digit string. +func (s *ss) getBase(verb int) (base int, digits string) { + s.okVerb(verb, "bdoUxXv", "integer") // sets s.err + base = 10 + digits = decimalDigits + switch verb { + case 'b': + base = 2 + digits = binaryDigits + case 'o': + base = 8 + digits = octalDigits + case 'x', 'X', 'U': + base = 16 + digits = hexadecimalDigits + } + return +} + +// scanNumber returns the numerical string with specified digits starting here. +func (s *ss) scanNumber(digits string, haveDigits bool) string { + if !haveDigits && !s.accept(digits) { + s.errorString("expected integer") + } + for s.accept(digits) { + } + return s.buf.String() +} + +// scanRune returns the next rune value in the input. 
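A concrete illustration, not part of the file, of the verb-to-base mapping that getBase implements:

func exampleBases() {
	// assumes: import "fmt"
	var b, o, x int
	// %b, %o and %x select bases 2, 8 and 16 respectively.
	fmt.Sscanf("101 17 ff", "%b %o %x", &b, &o, &x)
	// now b == 5, o == 15, x == 255
}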
+func (s *ss) scanRune(bitSize int) int64 { + rune := int64(s.mustReadRune()) + n := uint(bitSize) + x := (rune << (64 - n)) >> (64 - n) + if x != rune { + s.errorString("overflow on character value " + string(rune)) + } + return rune +} + +// scanBasePrefix reports whether the integer begins with a 0 or 0x, +// and returns the base, digit string, and whether a zero was found. +// It is called only if the verb is %v. +func (s *ss) scanBasePrefix() (base int, digits string, found bool) { + if !s.peek("0") { + return 10, decimalDigits, false + } + s.accept("0") + found = true // We've put a digit into the token buffer. + // Special cases for '0' && '0x' + base, digits = 8, octalDigits + if s.peek("xX") { + s.consume("xX", false) + base, digits = 16, hexadecimalDigits + } + return +} + +// scanInt returns the value of the integer represented by the next +// token, checking for overflow. Any error is stored in s.err. +func (s *ss) scanInt(verb int, bitSize int) int64 { + if verb == 'c' { + return s.scanRune(bitSize) + } + s.skipSpace(false) + base, digits := s.getBase(verb) + haveDigits := false + if verb == 'U' { + if !s.consume("U", false) || !s.consume("+", false) { + s.errorString("bad unicode format ") + } + } else { + s.accept(sign) // If there's a sign, it will be left in the token buffer. + if verb == 'v' { + base, digits, haveDigits = s.scanBasePrefix() + } + } + tok := s.scanNumber(digits, haveDigits) + i, err := strconv.Btoi64(tok, base) + if err != nil { + s.error(err) + } + n := uint(bitSize) + x := (i << (64 - n)) >> (64 - n) + if x != i { + s.errorString("integer overflow on token " + tok) + } + return i +} + +// scanUint returns the value of the unsigned integer represented +// by the next token, checking for overflow. Any error is stored in s.err. +func (s *ss) scanUint(verb int, bitSize int) uint64 { + if verb == 'c' { + return uint64(s.scanRune(bitSize)) + } + s.skipSpace(false) + base, digits := s.getBase(verb) + haveDigits := false + if verb == 'U' { + if !s.consume("U", false) || !s.consume("+", false) { + s.errorString("bad unicode format ") + } + } else if verb == 'v' { + base, digits, haveDigits = s.scanBasePrefix() + } + tok := s.scanNumber(digits, haveDigits) + i, err := strconv.Btoui64(tok, base) + if err != nil { + s.error(err) + } + n := uint(bitSize) + x := (i << (64 - n)) >> (64 - n) + if x != i { + s.errorString("unsigned integer overflow on token " + tok) + } + return i +} + +// floatToken returns the floating-point number starting here, no longer than swid +// if the width is specified. It's not rigorous about syntax because it doesn't check that +// we have at least some digits, but Atof will do that. +func (s *ss) floatToken() string { + s.buf.Reset() + // NaN? + if s.accept("nN") && s.accept("aA") && s.accept("nN") { + return s.buf.String() + } + // leading sign? + s.accept(sign) + // Inf? + if s.accept("iI") && s.accept("nN") && s.accept("fF") { + return s.buf.String() + } + // digits? + for s.accept(decimalDigits) { + } + // decimal point? + if s.accept(period) { + // fraction? + for s.accept(decimalDigits) { + } + } + // exponent? + if s.accept(exponent) { + // leading sign? + s.accept(sign) + // digits? + for s.accept(decimalDigits) { + } + } + return s.buf.String() +} + +// complexTokens returns the real and imaginary parts of the complex number starting here. +// The number might be parenthesized and has the format (N+Ni) where N is a floating-point +// number and there are no spaces within. 
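The overflow tests in scanInt and scanUint above rely on a shift round trip; restated as a standalone helper (illustrative only, mirroring the expression used above):

// fitsSigned reports whether i survives truncation to a signed value of the
// given bit size: the value is shifted to the top of the 64-bit word and
// sign-extended back, then compared with the original.
func fitsSigned(i int64, bitSize uint) bool {
	x := (i << (64 - bitSize)) >> (64 - bitSize)
	return x == i
}

// fitsSigned(127, 8) == true, fitsSigned(128, 8) == false, fitsSigned(-128, 8) == true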
+func (s *ss) complexTokens() (real, imag string) { + // TODO: accept N and Ni independently? + parens := s.accept("(") + real = s.floatToken() + s.buf.Reset() + // Must now have a sign. + if !s.accept("+-") { + s.error(complexError) + } + // Sign is now in buffer + imagSign := s.buf.String() + imag = s.floatToken() + if !s.accept("i") { + s.error(complexError) + } + if parens && !s.accept(")") { + s.error(complexError) + } + return real, imagSign + imag +} + +// convertFloat converts the string to a float64value. +func (s *ss) convertFloat(str string, n int) float64 { + if p := strings.Index(str, "p"); p >= 0 { + // Atof doesn't handle power-of-2 exponents, + // but they're easy to evaluate. + f, err := strconv.AtofN(str[:p], n) + if err != nil { + // Put full string into error. + if e, ok := err.(*strconv.NumError); ok { + e.Num = str + } + s.error(err) + } + n, err := strconv.Atoi(str[p+1:]) + if err != nil { + // Put full string into error. + if e, ok := err.(*strconv.NumError); ok { + e.Num = str + } + s.error(err) + } + return math.Ldexp(f, n) + } + f, err := strconv.AtofN(str, n) + if err != nil { + s.error(err) + } + return f +} + +// convertComplex converts the next token to a complex128 value. +// The atof argument is a type-specific reader for the underlying type. +// If we're reading complex64, atof will parse float32s and convert them +// to float64's to avoid reproducing this code for each complex type. +func (s *ss) scanComplex(verb int, n int) complex128 { + if !s.okVerb(verb, floatVerbs, "complex") { + return 0 + } + s.skipSpace(false) + sreal, simag := s.complexTokens() + real := s.convertFloat(sreal, n/2) + imag := s.convertFloat(simag, n/2) + return complex(real, imag) +} + +// convertString returns the string represented by the next input characters. +// The format of the input is determined by the verb. +func (s *ss) convertString(verb int) (str string) { + if !s.okVerb(verb, "svqx", "string") { + return "" + } + s.skipSpace(false) + switch verb { + case 'q': + str = s.quotedString() + case 'x': + str = s.hexString() + default: + str = string(s.token(true, notSpace)) // %s and %v just return the next word + } + // Empty strings other than with %q are not OK. + if len(str) == 0 && verb != 'q' && s.maxWid > 0 { + s.errorString("Scan: no data for string") + } + return +} + +// quotedString returns the double- or back-quoted string represented by the next input characters. +func (s *ss) quotedString() string { + quote := s.mustReadRune() + switch quote { + case '`': + // Back-quoted: Anything goes until EOF or back quote. + for { + rune := s.mustReadRune() + if rune == quote { + break + } + s.buf.WriteRune(rune) + } + return s.buf.String() + case '"': + // Double-quoted: Include the quotes and let strconv.Unquote do the backslash escapes. + s.buf.WriteRune(quote) + for { + rune := s.mustReadRune() + s.buf.WriteRune(rune) + if rune == '\\' { + // In a legal backslash escape, no matter how long, only the character + // immediately after the escape can itself be a backslash or quote. + // Thus we only need to protect the first character after the backslash. 
+ rune := s.mustReadRune() + s.buf.WriteRune(rune) + } else if rune == '"' { + break + } + } + result, err := strconv.Unquote(s.buf.String()) + if err != nil { + s.error(err) + } + return result + default: + s.errorString("expected quoted string") + } + return "" +} + +// hexDigit returns the value of the hexadecimal digit +func (s *ss) hexDigit(digit int) int { + switch digit { + case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9': + return digit - '0' + case 'a', 'b', 'c', 'd', 'e', 'f': + return 10 + digit - 'a' + case 'A', 'B', 'C', 'D', 'E', 'F': + return 10 + digit - 'A' + } + s.errorString("Scan: illegal hex digit") + return 0 +} + +// hexByte returns the next hex-encoded (two-character) byte from the input. +// There must be either two hexadecimal digits or a space character in the input. +func (s *ss) hexByte() (b byte, ok bool) { + rune1 := s.getRune() + if rune1 == eof { + return + } + if unicode.IsSpace(rune1) { + s.UnreadRune() + return + } + rune2 := s.mustReadRune() + return byte(s.hexDigit(rune1)<<4 | s.hexDigit(rune2)), true +} + +// hexString returns the space-delimited hexpair-encoded string. +func (s *ss) hexString() string { + for { + b, ok := s.hexByte() + if !ok { + break + } + s.buf.WriteByte(b) + } + if s.buf.Len() == 0 { + s.errorString("Scan: no hex data for %x string") + return "" + } + return s.buf.String() +} + +const floatVerbs = "beEfFgGv" + +const hugeWid = 1 << 30 + +// scanOne scans a single value, deriving the scanner from the type of the argument. +func (s *ss) scanOne(verb int, field interface{}) { + s.buf.Reset() + var err os.Error + // If the parameter has its own Scan method, use that. + if v, ok := field.(Scanner); ok { + err = v.Scan(s, verb) + if err != nil { + if err == os.EOF { + err = io.ErrUnexpectedEOF + } + s.error(err) + } + return + } + switch v := field.(type) { + case *bool: + *v = s.scanBool(verb) + case *complex64: + *v = complex64(s.scanComplex(verb, 64)) + case *complex128: + *v = s.scanComplex(verb, 128) + case *int: + *v = int(s.scanInt(verb, intBits)) + case *int8: + *v = int8(s.scanInt(verb, 8)) + case *int16: + *v = int16(s.scanInt(verb, 16)) + case *int32: + *v = int32(s.scanInt(verb, 32)) + case *int64: + *v = s.scanInt(verb, 64) + case *uint: + *v = uint(s.scanUint(verb, intBits)) + case *uint8: + *v = uint8(s.scanUint(verb, 8)) + case *uint16: + *v = uint16(s.scanUint(verb, 16)) + case *uint32: + *v = uint32(s.scanUint(verb, 32)) + case *uint64: + *v = s.scanUint(verb, 64) + case *uintptr: + *v = uintptr(s.scanUint(verb, uintptrBits)) + // Floats are tricky because you want to scan in the precision of the result, not + // scan in high precision and convert, in order to preserve the correct error condition. + case *float32: + if s.okVerb(verb, floatVerbs, "float32") { + s.skipSpace(false) + *v = float32(s.convertFloat(s.floatToken(), 32)) + } + case *float64: + if s.okVerb(verb, floatVerbs, "float64") { + s.skipSpace(false) + *v = s.convertFloat(s.floatToken(), 64) + } + case *string: + *v = s.convertString(verb) + case *[]byte: + // We scan to string and convert so we get a copy of the data. + // If we scanned to bytes, the slice would point at the buffer. 
+ *v = []byte(s.convertString(verb)) + default: + val := reflect.NewValue(v) + ptr := val + if ptr.Kind() != reflect.Ptr { + s.errorString("Scan: type not a pointer: " + val.Type().String()) + return + } + switch v := ptr.Elem(); v.Kind() { + case reflect.Bool: + v.SetBool(s.scanBool(verb)) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + v.SetInt(s.scanInt(verb, v.Type().Bits())) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + v.SetUint(s.scanUint(verb, v.Type().Bits())) + case reflect.String: + v.SetString(s.convertString(verb)) + case reflect.Slice: + // For now, can only handle (renamed) []byte. + typ := v.Type() + if typ.Elem().Kind() != reflect.Uint8 { + goto CantHandle + } + str := s.convertString(verb) + v.Set(reflect.MakeSlice(typ, len(str), len(str))) + for i := 0; i < len(str); i++ { + v.Index(i).SetUint(uint64(str[i])) + } + case reflect.Float32, reflect.Float64: + s.skipSpace(false) + v.SetFloat(s.convertFloat(s.floatToken(), v.Type().Bits())) + case reflect.Complex64, reflect.Complex128: + v.SetComplex(s.scanComplex(verb, v.Type().Bits())) + default: + CantHandle: + s.errorString("Scan: can't handle type: " + val.Type().String()) + } + } +} + +// errorHandler turns local panics into error returns. EOFs are benign. +func errorHandler(errp *os.Error) { + if e := recover(); e != nil { + if se, ok := e.(scanError); ok { // catch local error + if se.err != os.EOF { + *errp = se.err + } + } else { + panic(e) + } + } +} + +// doScan does the real work for scanning without a format string. +func (s *ss) doScan(a []interface{}) (numProcessed int, err os.Error) { + defer errorHandler(&err) + for _, field := range a { + s.scanOne('v', field) + numProcessed++ + } + // Check for newline if required. + if !s.nlIsSpace { + for { + rune := s.getRune() + if rune == '\n' || rune == eof { + break + } + if !unicode.IsSpace(rune) { + s.errorString("Scan: expected newline") + break + } + } + } + return +} + +// advance determines whether the next characters in the input match +// those of the format. It returns the number of bytes (sic) consumed +// in the format. Newlines included, all runs of space characters in +// either input or format behave as a single space. This routine also +// handles the %% case. If the return value is zero, either format +// starts with a % (with no following %) or the input is empty. +// If it is negative, the input did not match the string. +func (s *ss) advance(format string) (i int) { + for i < len(format) { + fmtc, w := utf8.DecodeRuneInString(format[i:]) + if fmtc == '%' { + // %% acts like a real percent + nextc, _ := utf8.DecodeRuneInString(format[i+w:]) // will not match % if string is empty + if nextc != '%' { + return + } + i += w // skip the first % + } + sawSpace := false + for unicode.IsSpace(fmtc) && i < len(format) { + sawSpace = true + i += w + fmtc, w = utf8.DecodeRuneInString(format[i:]) + } + if sawSpace { + // There was space in the format, so there should be space (EOF) + // in the input. + inputc := s.getRune() + if inputc == eof { + return + } + if !unicode.IsSpace(inputc) { + // Space in format but not in input: error + s.errorString("expected space in input to match format") + } + s.skipSpace(true) + continue + } + inputc := s.mustReadRune() + if fmtc != inputc { + s.UnreadRune() + return -1 + } + i += w + } + return +} + +// doScanf does the real work when scanning with a format string. +// At the moment, it handles only pointers to basic types. 
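An aside on what advance (above) means for callers, not part of the file: literal text in the format must match the input, runs of spaces match runs of spaces, and %% consumes a literal percent sign.

func exampleLiterals() {
	// assumes: import "fmt"
	var day, year int
	n, err := fmt.Sscanf("day 13 of 2011 (100%)", "day %d of %d (100%%)", &day, &year)
	_, _ = n, err // n == 2, day == 13, year == 2011, err == nil
}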
+func (s *ss) doScanf(format string, a []interface{}) (numProcessed int, err os.Error) { + defer errorHandler(&err) + end := len(format) - 1 + // We process one item per non-trivial format + for i := 0; i <= end; { + w := s.advance(format[i:]) + if w > 0 { + i += w + continue + } + // Either we failed to advance, we have a percent character, or we ran out of input. + if format[i] != '%' { + // Can't advance format. Why not? + if w < 0 { + s.errorString("input does not match format") + } + // Otherwise at EOF; "too many operands" error handled below + break + } + i++ // % is one byte + + // do we have 20 (width)? + var widPresent bool + s.maxWid, widPresent, i = parsenum(format, i, end) + if !widPresent { + s.maxWid = hugeWid + } + s.fieldLimit = s.limit + if f := s.count + s.maxWid; f < s.fieldLimit { + s.fieldLimit = f + } + + c, w := utf8.DecodeRuneInString(format[i:]) + i += w + + if numProcessed >= len(a) { // out of operands + s.errorString("too few operands for format %" + format[i-w:]) + break + } + field := a[numProcessed] + + s.scanOne(c, field) + numProcessed++ + s.fieldLimit = s.limit + } + if numProcessed < len(a) { + s.errorString("too many operands") + } + return +} diff --git a/src/cmd/gofix/testdata/reflect.script.go.in b/src/cmd/gofix/testdata/reflect.script.go.in new file mode 100644 index 000000000..b341b1f89 --- /dev/null +++ b/src/cmd/gofix/testdata/reflect.script.go.in @@ -0,0 +1,359 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This package aids in the testing of code that uses channels. +package script + +import ( + "fmt" + "os" + "rand" + "reflect" + "strings" +) + +// An Event is an element in a partially ordered set that either sends a value +// to a channel or expects a value from a channel. +type Event struct { + name string + occurred bool + predecessors []*Event + action action +} + +type action interface { + // getSend returns nil if the action is not a send action. + getSend() sendAction + // getRecv returns nil if the action is not a receive action. + getRecv() recvAction + // getChannel returns the channel that the action operates on. + getChannel() interface{} +} + +type recvAction interface { + recvMatch(interface{}) bool +} + +type sendAction interface { + send() +} + +// isReady returns true if all the predecessors of an Event have occurred. +func (e Event) isReady() bool { + for _, predecessor := range e.predecessors { + if !predecessor.occurred { + return false + } + } + + return true +} + +// A Recv action reads a value from a channel and uses reflect.DeepMatch to +// compare it with an expected value. +type Recv struct { + Channel interface{} + Expected interface{} +} + +func (r Recv) getRecv() recvAction { return r } + +func (Recv) getSend() sendAction { return nil } + +func (r Recv) getChannel() interface{} { return r.Channel } + +func (r Recv) recvMatch(chanEvent interface{}) bool { + c, ok := chanEvent.(channelRecv) + if !ok || c.channel != r.Channel { + return false + } + + return reflect.DeepEqual(c.value, r.Expected) +} + +// A RecvMatch action reads a value from a channel and calls a function to +// determine if the value matches. 
+type RecvMatch struct { + Channel interface{} + Match func(interface{}) bool +} + +func (r RecvMatch) getRecv() recvAction { return r } + +func (RecvMatch) getSend() sendAction { return nil } + +func (r RecvMatch) getChannel() interface{} { return r.Channel } + +func (r RecvMatch) recvMatch(chanEvent interface{}) bool { + c, ok := chanEvent.(channelRecv) + if !ok || c.channel != r.Channel { + return false + } + + return r.Match(c.value) +} + +// A Closed action matches if the given channel is closed. The closing is +// treated as an event, not a state, thus Closed will only match once for a +// given channel. +type Closed struct { + Channel interface{} +} + +func (r Closed) getRecv() recvAction { return r } + +func (Closed) getSend() sendAction { return nil } + +func (r Closed) getChannel() interface{} { return r.Channel } + +func (r Closed) recvMatch(chanEvent interface{}) bool { + c, ok := chanEvent.(channelClosed) + if !ok || c.channel != r.Channel { + return false + } + + return true +} + +// A Send action sends a value to a channel. The value must match the +// type of the channel exactly unless the channel if of type chan interface{}. +type Send struct { + Channel interface{} + Value interface{} +} + +func (Send) getRecv() recvAction { return nil } + +func (s Send) getSend() sendAction { return s } + +func (s Send) getChannel() interface{} { return s.Channel } + +type empty struct { + x interface{} +} + +func newEmptyInterface(e empty) reflect.Value { + return reflect.NewValue(e).(*reflect.StructValue).Field(0) +} + +func (s Send) send() { + // With reflect.ChanValue.Send, we must match the types exactly. So, if + // s.Channel is a chan interface{} we convert s.Value to an interface{} + // first. + c := reflect.NewValue(s.Channel).(*reflect.ChanValue) + var v reflect.Value + if iface, ok := c.Type().(*reflect.ChanType).Elem().(*reflect.InterfaceType); ok && iface.NumMethod() == 0 { + v = newEmptyInterface(empty{s.Value}) + } else { + v = reflect.NewValue(s.Value) + } + c.Send(v) +} + +// A Close action closes the given channel. +type Close struct { + Channel interface{} +} + +func (Close) getRecv() recvAction { return nil } + +func (s Close) getSend() sendAction { return s } + +func (s Close) getChannel() interface{} { return s.Channel } + +func (s Close) send() { reflect.NewValue(s.Channel).(*reflect.ChanValue).Close() } + +// A ReceivedUnexpected error results if no active Events match a value +// received from a channel. +type ReceivedUnexpected struct { + Value interface{} + ready []*Event +} + +func (r ReceivedUnexpected) String() string { + names := make([]string, len(r.ready)) + for i, v := range r.ready { + names[i] = v.name + } + return fmt.Sprintf("received unexpected value on one of the channels: %#v. Runnable events: %s", r.Value, strings.Join(names, ", ")) +} + +// A SetupError results if there is a error with the configuration of a set of +// Events. +type SetupError string + +func (s SetupError) String() string { return string(s) } + +func NewEvent(name string, predecessors []*Event, action action) *Event { + e := &Event{name, false, predecessors, action} + return e +} + +// Given a set of Events, Perform repeatedly iterates over the set and finds the +// subset of ready Events (that is, all of their predecessors have +// occurred). From that subset, it pseudo-randomly selects an Event to perform. +// If the Event is a send event, the send occurs and Perform recalculates the ready +// set. 
If the event is a receive event, Perform waits for a value from any of the +// channels that are contained in any of the events. That value is then matched +// against the ready events. The first event that matches is considered to +// have occurred and Perform recalculates the ready set. +// +// Perform continues this until all Events have occurred. +// +// Note that uncollected goroutines may still be reading from any of the +// channels read from after Perform returns. +// +// For example, consider the problem of testing a function that reads values on +// one channel and echos them to two output channels. To test this we would +// create three events: a send event and two receive events. Each of the +// receive events must list the send event as a predecessor but there is no +// ordering between the receive events. +// +// send := NewEvent("send", nil, Send{c, 1}) +// recv1 := NewEvent("recv 1", []*Event{send}, Recv{c, 1}) +// recv2 := NewEvent("recv 2", []*Event{send}, Recv{c, 1}) +// Perform(0, []*Event{send, recv1, recv2}) +// +// At first, only the send event would be in the ready set and thus Perform will +// send a value to the input channel. Now the two receive events are ready and +// Perform will match each of them against the values read from the output channels. +// +// It would be invalid to list one of the receive events as a predecessor of +// the other. At each receive step, all the receive channels are considered, +// thus Perform may see a value from a channel that is not in the current ready +// set and fail. +func Perform(seed int64, events []*Event) (err os.Error) { + r := rand.New(rand.NewSource(seed)) + + channels, err := getChannels(events) + if err != nil { + return + } + multiplex := make(chan interface{}) + for _, channel := range channels { + go recvValues(multiplex, channel) + } + +Outer: + for { + ready, err := readyEvents(events) + if err != nil { + return err + } + + if len(ready) == 0 { + // All events occurred. + break + } + + event := ready[r.Intn(len(ready))] + if send := event.action.getSend(); send != nil { + send.send() + event.occurred = true + continue + } + + v := <-multiplex + for _, event := range ready { + if recv := event.action.getRecv(); recv != nil && recv.recvMatch(v) { + event.occurred = true + continue Outer + } + } + + return ReceivedUnexpected{v, ready} + } + + return nil +} + +// getChannels returns all the channels listed in any receive events. +func getChannels(events []*Event) ([]interface{}, os.Error) { + channels := make([]interface{}, len(events)) + + j := 0 + for _, event := range events { + if recv := event.action.getRecv(); recv == nil { + continue + } + c := event.action.getChannel() + if _, ok := reflect.NewValue(c).(*reflect.ChanValue); !ok { + return nil, SetupError("one of the channel values is not a channel") + } + + duplicate := false + for _, other := range channels[0:j] { + if c == other { + duplicate = true + break + } + } + + if !duplicate { + channels[j] = c + j++ + } + } + + return channels[0:j], nil +} + +// recvValues is a multiplexing helper function. It reads values from the given +// channel repeatedly, wrapping them up as either a channelRecv or +// channelClosed structure, and forwards them to the multiplex channel. 
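A small usage sketch for the matcher actions above, complementing the send/receive example in the Perform comment. The wrapper function and the assumption that this package is importable under the name script are illustrative.

// assumes: import "os" and this package under the name script
func exampleEvenRecv() os.Error {
	c := make(chan int)
	send := script.NewEvent("send", nil, script.Send{Channel: c, Value: 4})
	even := script.NewEvent("even?", []*script.Event{send}, script.RecvMatch{
		Channel: c,
		Match:   func(v interface{}) bool { n, ok := v.(int); return ok && n%2 == 0 },
	})
	// Perform sends 4 on c, then matches the received value against the
	// RecvMatch predicate; it returns nil once both events have occurred.
	return script.Perform(0, []*script.Event{send, even})
}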
+func recvValues(multiplex chan<- interface{}, channel interface{}) { + c := reflect.NewValue(channel).(*reflect.ChanValue) + + for { + v, ok := c.Recv() + if !ok { + multiplex <- channelClosed{channel} + return + } + + multiplex <- channelRecv{channel, v.Interface()} + } +} + +type channelClosed struct { + channel interface{} +} + +type channelRecv struct { + channel interface{} + value interface{} +} + +// readyEvents returns the subset of events that are ready. +func readyEvents(events []*Event) ([]*Event, os.Error) { + ready := make([]*Event, len(events)) + + j := 0 + eventsWaiting := false + for _, event := range events { + if event.occurred { + continue + } + + eventsWaiting = true + if event.isReady() { + ready[j] = event + j++ + } + } + + if j == 0 && eventsWaiting { + names := make([]string, len(events)) + for _, event := range events { + if event.occurred { + continue + } + names[j] = event.name + } + + return nil, SetupError("dependency cycle in events. These events are waiting to run but cannot: " + strings.Join(names, ", ")) + } + + return ready[0:j], nil +} diff --git a/src/cmd/gofix/testdata/reflect.script.go.out b/src/cmd/gofix/testdata/reflect.script.go.out new file mode 100644 index 000000000..b18018497 --- /dev/null +++ b/src/cmd/gofix/testdata/reflect.script.go.out @@ -0,0 +1,359 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This package aids in the testing of code that uses channels. +package script + +import ( + "fmt" + "os" + "rand" + "reflect" + "strings" +) + +// An Event is an element in a partially ordered set that either sends a value +// to a channel or expects a value from a channel. +type Event struct { + name string + occurred bool + predecessors []*Event + action action +} + +type action interface { + // getSend returns nil if the action is not a send action. + getSend() sendAction + // getRecv returns nil if the action is not a receive action. + getRecv() recvAction + // getChannel returns the channel that the action operates on. + getChannel() interface{} +} + +type recvAction interface { + recvMatch(interface{}) bool +} + +type sendAction interface { + send() +} + +// isReady returns true if all the predecessors of an Event have occurred. +func (e Event) isReady() bool { + for _, predecessor := range e.predecessors { + if !predecessor.occurred { + return false + } + } + + return true +} + +// A Recv action reads a value from a channel and uses reflect.DeepMatch to +// compare it with an expected value. +type Recv struct { + Channel interface{} + Expected interface{} +} + +func (r Recv) getRecv() recvAction { return r } + +func (Recv) getSend() sendAction { return nil } + +func (r Recv) getChannel() interface{} { return r.Channel } + +func (r Recv) recvMatch(chanEvent interface{}) bool { + c, ok := chanEvent.(channelRecv) + if !ok || c.channel != r.Channel { + return false + } + + return reflect.DeepEqual(c.value, r.Expected) +} + +// A RecvMatch action reads a value from a channel and calls a function to +// determine if the value matches. 
+type RecvMatch struct { + Channel interface{} + Match func(interface{}) bool +} + +func (r RecvMatch) getRecv() recvAction { return r } + +func (RecvMatch) getSend() sendAction { return nil } + +func (r RecvMatch) getChannel() interface{} { return r.Channel } + +func (r RecvMatch) recvMatch(chanEvent interface{}) bool { + c, ok := chanEvent.(channelRecv) + if !ok || c.channel != r.Channel { + return false + } + + return r.Match(c.value) +} + +// A Closed action matches if the given channel is closed. The closing is +// treated as an event, not a state, thus Closed will only match once for a +// given channel. +type Closed struct { + Channel interface{} +} + +func (r Closed) getRecv() recvAction { return r } + +func (Closed) getSend() sendAction { return nil } + +func (r Closed) getChannel() interface{} { return r.Channel } + +func (r Closed) recvMatch(chanEvent interface{}) bool { + c, ok := chanEvent.(channelClosed) + if !ok || c.channel != r.Channel { + return false + } + + return true +} + +// A Send action sends a value to a channel. The value must match the +// type of the channel exactly unless the channel if of type chan interface{}. +type Send struct { + Channel interface{} + Value interface{} +} + +func (Send) getRecv() recvAction { return nil } + +func (s Send) getSend() sendAction { return s } + +func (s Send) getChannel() interface{} { return s.Channel } + +type empty struct { + x interface{} +} + +func newEmptyInterface(e empty) reflect.Value { + return reflect.NewValue(e).Field(0) +} + +func (s Send) send() { + // With reflect.ChanValue.Send, we must match the types exactly. So, if + // s.Channel is a chan interface{} we convert s.Value to an interface{} + // first. + c := reflect.NewValue(s.Channel) + var v reflect.Value + if iface := c.Type().Elem(); iface.Kind() == reflect.Interface && iface.NumMethod() == 0 { + v = newEmptyInterface(empty{s.Value}) + } else { + v = reflect.NewValue(s.Value) + } + c.Send(v) +} + +// A Close action closes the given channel. +type Close struct { + Channel interface{} +} + +func (Close) getRecv() recvAction { return nil } + +func (s Close) getSend() sendAction { return s } + +func (s Close) getChannel() interface{} { return s.Channel } + +func (s Close) send() { reflect.NewValue(s.Channel).Close() } + +// A ReceivedUnexpected error results if no active Events match a value +// received from a channel. +type ReceivedUnexpected struct { + Value interface{} + ready []*Event +} + +func (r ReceivedUnexpected) String() string { + names := make([]string, len(r.ready)) + for i, v := range r.ready { + names[i] = v.name + } + return fmt.Sprintf("received unexpected value on one of the channels: %#v. Runnable events: %s", r.Value, strings.Join(names, ", ")) +} + +// A SetupError results if there is a error with the configuration of a set of +// Events. +type SetupError string + +func (s SetupError) String() string { return string(s) } + +func NewEvent(name string, predecessors []*Event, action action) *Event { + e := &Event{name, false, predecessors, action} + return e +} + +// Given a set of Events, Perform repeatedly iterates over the set and finds the +// subset of ready Events (that is, all of their predecessors have +// occurred). From that subset, it pseudo-randomly selects an Event to perform. +// If the Event is a send event, the send occurs and Perform recalculates the ready +// set. If the event is a receive event, Perform waits for a value from any of the +// channels that are contained in any of the events. 
That value is then matched +// against the ready events. The first event that matches is considered to +// have occurred and Perform recalculates the ready set. +// +// Perform continues this until all Events have occurred. +// +// Note that uncollected goroutines may still be reading from any of the +// channels read from after Perform returns. +// +// For example, consider the problem of testing a function that reads values on +// one channel and echos them to two output channels. To test this we would +// create three events: a send event and two receive events. Each of the +// receive events must list the send event as a predecessor but there is no +// ordering between the receive events. +// +// send := NewEvent("send", nil, Send{c, 1}) +// recv1 := NewEvent("recv 1", []*Event{send}, Recv{c, 1}) +// recv2 := NewEvent("recv 2", []*Event{send}, Recv{c, 1}) +// Perform(0, []*Event{send, recv1, recv2}) +// +// At first, only the send event would be in the ready set and thus Perform will +// send a value to the input channel. Now the two receive events are ready and +// Perform will match each of them against the values read from the output channels. +// +// It would be invalid to list one of the receive events as a predecessor of +// the other. At each receive step, all the receive channels are considered, +// thus Perform may see a value from a channel that is not in the current ready +// set and fail. +func Perform(seed int64, events []*Event) (err os.Error) { + r := rand.New(rand.NewSource(seed)) + + channels, err := getChannels(events) + if err != nil { + return + } + multiplex := make(chan interface{}) + for _, channel := range channels { + go recvValues(multiplex, channel) + } + +Outer: + for { + ready, err := readyEvents(events) + if err != nil { + return err + } + + if len(ready) == 0 { + // All events occurred. + break + } + + event := ready[r.Intn(len(ready))] + if send := event.action.getSend(); send != nil { + send.send() + event.occurred = true + continue + } + + v := <-multiplex + for _, event := range ready { + if recv := event.action.getRecv(); recv != nil && recv.recvMatch(v) { + event.occurred = true + continue Outer + } + } + + return ReceivedUnexpected{v, ready} + } + + return nil +} + +// getChannels returns all the channels listed in any receive events. +func getChannels(events []*Event) ([]interface{}, os.Error) { + channels := make([]interface{}, len(events)) + + j := 0 + for _, event := range events { + if recv := event.action.getRecv(); recv == nil { + continue + } + c := event.action.getChannel() + if reflect.NewValue(c).Kind() != reflect.Chan { + return nil, SetupError("one of the channel values is not a channel") + } + + duplicate := false + for _, other := range channels[0:j] { + if c == other { + duplicate = true + break + } + } + + if !duplicate { + channels[j] = c + j++ + } + } + + return channels[0:j], nil +} + +// recvValues is a multiplexing helper function. It reads values from the given +// channel repeatedly, wrapping them up as either a channelRecv or +// channelClosed structure, and forwards them to the multiplex channel. 
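These .in/.out pairs exercise gofix's reflect rewrite; the essence of the transformation, pulled out of the two versions above as an illustrative summary rather than additional test data:

// Before (reflect.script.go.in):
//	c := reflect.NewValue(channel).(*reflect.ChanValue)
//	if _, ok := reflect.NewValue(c).(*reflect.ChanValue); !ok { ... }
//	reflect.NewValue(e).(*reflect.StructValue).Field(0)
//
// After (reflect.script.go.out):
//	c := reflect.NewValue(channel)
//	if reflect.NewValue(c).Kind() != reflect.Chan { ... }
//	reflect.NewValue(e).Field(0)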
+func recvValues(multiplex chan<- interface{}, channel interface{}) { + c := reflect.NewValue(channel) + + for { + v, ok := c.Recv() + if !ok { + multiplex <- channelClosed{channel} + return + } + + multiplex <- channelRecv{channel, v.Interface()} + } +} + +type channelClosed struct { + channel interface{} +} + +type channelRecv struct { + channel interface{} + value interface{} +} + +// readyEvents returns the subset of events that are ready. +func readyEvents(events []*Event) ([]*Event, os.Error) { + ready := make([]*Event, len(events)) + + j := 0 + eventsWaiting := false + for _, event := range events { + if event.occurred { + continue + } + + eventsWaiting = true + if event.isReady() { + ready[j] = event + j++ + } + } + + if j == 0 && eventsWaiting { + names := make([]string, len(events)) + for _, event := range events { + if event.occurred { + continue + } + names[j] = event.name + } + + return nil, SetupError("dependency cycle in events. These events are waiting to run but cannot: " + strings.Join(names, ", ")) + } + + return ready[0:j], nil +} diff --git a/src/cmd/gofix/testdata/reflect.template.go.in b/src/cmd/gofix/testdata/reflect.template.go.in new file mode 100644 index 000000000..ba06de4e3 --- /dev/null +++ b/src/cmd/gofix/testdata/reflect.template.go.in @@ -0,0 +1,1043 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +/* + Data-driven templates for generating textual output such as + HTML. + + Templates are executed by applying them to a data structure. + Annotations in the template refer to elements of the data + structure (typically a field of a struct or a key in a map) + to control execution and derive values to be displayed. + The template walks the structure as it executes and the + "cursor" @ represents the value at the current location + in the structure. + + Data items may be values or pointers; the interface hides the + indirection. + + In the following, 'field' is one of several things, according to the data. + + - The name of a field of a struct (result = data.field), + - The value stored in a map under that key (result = data[field]), or + - The result of invoking a niladic single-valued method with that name + (result = data.field()) + + Major constructs ({} are the default delimiters for template actions; + [] are the notation in this comment for optional elements): + + {# comment } + + A one-line comment. + + {.section field} XXX [ {.or} YYY ] {.end} + + Set @ to the value of the field. It may be an explicit @ + to stay at the same point in the data. If the field is nil + or empty, execute YYY; otherwise execute XXX. + + {.repeated section field} XXX [ {.alternates with} ZZZ ] [ {.or} YYY ] {.end} + + Like .section, but field must be an array or slice. XXX + is executed for each element. If the array is nil or empty, + YYY is executed instead. If the {.alternates with} marker + is present, ZZZ is executed between iterations of XXX. + + {field} + {field1 field2 ...} + {field|formatter} + {field1 field2...|formatter} + {field|formatter1|formatter2} + + Insert the value of the fields into the output. Each field is + first looked for in the cursor, as in .section and .repeated. + If it is not found, the search continues in outer sections + until the top level is reached. + + If the field value is a pointer, leading asterisks indicate + that the value to be inserted should be evaluated through the + pointer. 
For example, if x.p is of type *int, {x.p} will + insert the value of the pointer but {*x.p} will insert the + value of the underlying integer. If the value is nil or not a + pointer, asterisks have no effect. + + If a formatter is specified, it must be named in the formatter + map passed to the template set up routines or in the default + set ("html","str","") and is used to process the data for + output. The formatter function has signature + func(wr io.Writer, formatter string, data ...interface{}) + where wr is the destination for output, data holds the field + values at the instantiation, and formatter is its name at + the invocation site. The default formatter just concatenates + the string representations of the fields. + + Multiple formatters separated by the pipeline character | are + executed sequentially, with each formatter receiving the bytes + emitted by the one to its left. + + The delimiter strings get their default value, "{" and "}", from + JSON-template. They may be set to any non-empty, space-free + string using the SetDelims method. Their value can be printed + in the output using {.meta-left} and {.meta-right}. +*/ +package template + +import ( + "bytes" + "container/vector" + "fmt" + "io" + "io/ioutil" + "os" + "reflect" + "strings" + "unicode" + "utf8" +) + +// Errors returned during parsing and execution. Users may extract the information and reformat +// if they desire. +type Error struct { + Line int + Msg string +} + +func (e *Error) String() string { return fmt.Sprintf("line %d: %s", e.Line, e.Msg) } + +// Most of the literals are aces. +var lbrace = []byte{'{'} +var rbrace = []byte{'}'} +var space = []byte{' '} +var tab = []byte{'\t'} + +// The various types of "tokens", which are plain text or (usually) brace-delimited descriptors +const ( + tokAlternates = iota + tokComment + tokEnd + tokLiteral + tokOr + tokRepeated + tokSection + tokText + tokVariable +) + +// FormatterMap is the type describing the mapping from formatter +// names to the functions that implement them. +type FormatterMap map[string]func(io.Writer, string, ...interface{}) + +// Built-in formatters. +var builtins = FormatterMap{ + "html": HTMLFormatter, + "str": StringFormatter, + "": StringFormatter, +} + +// The parsed state of a template is a vector of xxxElement structs. +// Sections have line numbers so errors can be reported better during execution. + +// Plain text. +type textElement struct { + text []byte +} + +// A literal such as .meta-left or .meta-right +type literalElement struct { + text []byte +} + +// A variable invocation to be evaluated +type variableElement struct { + linenum int + word []string // The fields in the invocation. + fmts []string // Names of formatters to apply. len(fmts) > 0 +} + +// A .section block, possibly with a .or +type sectionElement struct { + linenum int // of .section itself + field string // cursor field for this block + start int // first element + or int // first element of .or block + end int // one beyond last element +} + +// A .repeated block, possibly with a .or and a .alternates +type repeatedElement struct { + sectionElement // It has the same structure... + altstart int // ... except for alternates + altend int +} + +// Template is the type that represents a template definition. +// It is unchanged after parsing. 
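A short sketch of the syntax documented in the package comment above, paired with a matching data shape; the names are illustrative and the parse/execute calls are omitted here.

const letter = `Hello {Name|html},
{.repeated section Items}
 - {Title}
{.or}
 (no items)
{.end}
`

type Item struct {
	Title string
}

type Order struct {
	Name  string // inserted by {Name|html}, escaped by the built-in html formatter
	Items []Item // iterated by {.repeated section Items}
}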
+type Template struct { + fmap FormatterMap // formatters for variables + // Used during parsing: + ldelim, rdelim []byte // delimiters; default {} + buf []byte // input text to process + p int // position in buf + linenum int // position in input + // Parsed results: + elems *vector.Vector +} + +// Internal state for executing a Template. As we evaluate the struct, +// the data item descends into the fields associated with sections, etc. +// Parent is used to walk upwards to find variables higher in the tree. +type state struct { + parent *state // parent in hierarchy + data reflect.Value // the driver data for this section etc. + wr io.Writer // where to send output + buf [2]bytes.Buffer // alternating buffers used when chaining formatters +} + +func (parent *state) clone(data reflect.Value) *state { + return &state{parent: parent, data: data, wr: parent.wr} +} + +// New creates a new template with the specified formatter map (which +// may be nil) to define auxiliary functions for formatting variables. +func New(fmap FormatterMap) *Template { + t := new(Template) + t.fmap = fmap + t.ldelim = lbrace + t.rdelim = rbrace + t.elems = new(vector.Vector) + return t +} + +// Report error and stop executing. The line number must be provided explicitly. +func (t *Template) execError(st *state, line int, err string, args ...interface{}) { + panic(&Error{line, fmt.Sprintf(err, args...)}) +} + +// Report error, panic to terminate parsing. +// The line number comes from the template state. +func (t *Template) parseError(err string, args ...interface{}) { + panic(&Error{t.linenum, fmt.Sprintf(err, args...)}) +} + +// Is this an exported - upper case - name? +func isExported(name string) bool { + rune, _ := utf8.DecodeRuneInString(name) + return unicode.IsUpper(rune) +} + +// -- Lexical analysis + +// Is c a white space character? +func white(c uint8) bool { return c == ' ' || c == '\t' || c == '\r' || c == '\n' } + +// Safely, does s[n:n+len(t)] == t? +func equal(s []byte, n int, t []byte) bool { + b := s[n:] + if len(t) > len(b) { // not enough space left for a match. + return false + } + for i, c := range t { + if c != b[i] { + return false + } + } + return true +} + +// nextItem returns the next item from the input buffer. If the returned +// item is empty, we are at EOF. The item will be either a +// delimited string or a non-empty string between delimited +// strings. Tokens stop at (but include, if plain text) a newline. +// Action tokens on a line by themselves drop any space on +// either side, up to and including the newline. +func (t *Template) nextItem() []byte { + startOfLine := t.p == 0 || t.buf[t.p-1] == '\n' + start := t.p + var i int + newline := func() { + t.linenum++ + i++ + } + // Leading white space up to but not including newline + for i = start; i < len(t.buf); i++ { + if t.buf[i] == '\n' || !white(t.buf[i]) { + break + } + } + leadingSpace := i > start + // What's left is nothing, newline, delimited string, or plain text + switch { + case i == len(t.buf): + // EOF; nothing to do + case t.buf[i] == '\n': + newline() + case equal(t.buf, i, t.ldelim): + left := i // Start of left delimiter. + right := -1 // Will be (immediately after) right delimiter. + haveText := false // Delimiters contain text. + i += len(t.ldelim) + // Find the end of the action. 
+ for ; i < len(t.buf); i++ { + if t.buf[i] == '\n' { + break + } + if equal(t.buf, i, t.rdelim) { + i += len(t.rdelim) + right = i + break + } + haveText = true + } + if right < 0 { + t.parseError("unmatched opening delimiter") + return nil + } + // Is this a special action (starts with '.' or '#') and the only thing on the line? + if startOfLine && haveText { + firstChar := t.buf[left+len(t.ldelim)] + if firstChar == '.' || firstChar == '#' { + // It's special and the first thing on the line. Is it the last? + for j := right; j < len(t.buf) && white(t.buf[j]); j++ { + if t.buf[j] == '\n' { + // Yes it is. Drop the surrounding space and return the {.foo} + t.linenum++ + t.p = j + 1 + return t.buf[left:right] + } + } + } + } + // No it's not. If there's leading space, return that. + if leadingSpace { + // not trimming space: return leading white space if there is some. + t.p = left + return t.buf[start:left] + } + // Return the word, leave the trailing space. + start = left + break + default: + for ; i < len(t.buf); i++ { + if t.buf[i] == '\n' { + newline() + break + } + if equal(t.buf, i, t.ldelim) { + break + } + } + } + item := t.buf[start:i] + t.p = i + return item +} + +// Turn a byte array into a white-space-split array of strings. +func words(buf []byte) []string { + s := make([]string, 0, 5) + p := 0 // position in buf + // one word per loop + for i := 0; ; i++ { + // skip white space + for ; p < len(buf) && white(buf[p]); p++ { + } + // grab word + start := p + for ; p < len(buf) && !white(buf[p]); p++ { + } + if start == p { // no text left + break + } + s = append(s, string(buf[start:p])) + } + return s +} + +// Analyze an item and return its token type and, if it's an action item, an array of +// its constituent words. +func (t *Template) analyze(item []byte) (tok int, w []string) { + // item is known to be non-empty + if !equal(item, 0, t.ldelim) { // doesn't start with left delimiter + tok = tokText + return + } + if !equal(item, len(item)-len(t.rdelim), t.rdelim) { // doesn't end with right delimiter + t.parseError("internal error: unmatched opening delimiter") // lexing should prevent this + return + } + if len(item) <= len(t.ldelim)+len(t.rdelim) { // no contents + t.parseError("empty directive") + return + } + // Comment + if item[len(t.ldelim)] == '#' { + tok = tokComment + return + } + // Split into words + w = words(item[len(t.ldelim) : len(item)-len(t.rdelim)]) // drop final delimiter + if len(w) == 0 { + t.parseError("empty directive") + return + } + if len(w) > 0 && w[0][0] != '.' { + tok = tokVariable + return + } + switch w[0] { + case ".meta-left", ".meta-right", ".space", ".tab": + tok = tokLiteral + return + case ".or": + tok = tokOr + return + case ".end": + tok = tokEnd + return + case ".section": + if len(w) != 2 { + t.parseError("incorrect fields for .section: %s", item) + return + } + tok = tokSection + return + case ".repeated": + if len(w) != 3 || w[1] != "section" { + t.parseError("incorrect fields for .repeated: %s", item) + return + } + tok = tokRepeated + return + case ".alternates": + if len(w) != 2 || w[1] != "with" { + t.parseError("incorrect fields for .alternates: %s", item) + return + } + tok = tokAlternates + return + } + t.parseError("bad directive: %s", item) + return +} + +// formatter returns the Formatter with the given name in the Template, or nil if none exists. 
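A sketch of a user-supplied formatter matching the FormatterMap signature above; only New, which appears earlier in this file, is assumed, and the imports are the obvious ones.

// assumes: import ("fmt"; "io"; "strings")
func upper(w io.Writer, format string, data ...interface{}) {
	for _, d := range data {
		fmt.Fprint(w, strings.ToUpper(fmt.Sprint(d)))
	}
}

// t := New(FormatterMap{"upper": upper})
// ...after which {Field|upper} runs the formatter on the field's value.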
+func (t *Template) formatter(name string) func(io.Writer, string, ...interface{}) { + if t.fmap != nil { + if fn := t.fmap[name]; fn != nil { + return fn + } + } + return builtins[name] +} + +// -- Parsing + +// Allocate a new variable-evaluation element. +func (t *Template) newVariable(words []string) *variableElement { + // After the final space-separated argument, formatters may be specified separated + // by pipe symbols, for example: {a b c|d|e} + + // Until we learn otherwise, formatters contains a single name: "", the default formatter. + formatters := []string{""} + lastWord := words[len(words)-1] + bar := strings.IndexRune(lastWord, '|') + if bar >= 0 { + words[len(words)-1] = lastWord[0:bar] + formatters = strings.Split(lastWord[bar+1:], "|", -1) + } + + // We could remember the function address here and avoid the lookup later, + // but it's more dynamic to let the user change the map contents underfoot. + // We do require the name to be present, though. + + // Is it in user-supplied map? + for _, f := range formatters { + if t.formatter(f) == nil { + t.parseError("unknown formatter: %q", f) + } + } + return &variableElement{t.linenum, words, formatters} +} + +// Grab the next item. If it's simple, just append it to the template. +// Otherwise return its details. +func (t *Template) parseSimple(item []byte) (done bool, tok int, w []string) { + tok, w = t.analyze(item) + done = true // assume for simplicity + switch tok { + case tokComment: + return + case tokText: + t.elems.Push(&textElement{item}) + return + case tokLiteral: + switch w[0] { + case ".meta-left": + t.elems.Push(&literalElement{t.ldelim}) + case ".meta-right": + t.elems.Push(&literalElement{t.rdelim}) + case ".space": + t.elems.Push(&literalElement{space}) + case ".tab": + t.elems.Push(&literalElement{tab}) + default: + t.parseError("internal error: unknown literal: %s", w[0]) + } + return + case tokVariable: + t.elems.Push(t.newVariable(w)) + return + } + return false, tok, w +} + +// parseRepeated and parseSection are mutually recursive + +func (t *Template) parseRepeated(words []string) *repeatedElement { + r := new(repeatedElement) + t.elems.Push(r) + r.linenum = t.linenum + r.field = words[2] + // Scan section, collecting true and false (.or) blocks. + r.start = t.elems.Len() + r.or = -1 + r.altstart = -1 + r.altend = -1 +Loop: + for { + item := t.nextItem() + if len(item) == 0 { + t.parseError("missing .end for .repeated section") + break + } + done, tok, w := t.parseSimple(item) + if done { + continue + } + switch tok { + case tokEnd: + break Loop + case tokOr: + if r.or >= 0 { + t.parseError("extra .or in .repeated section") + break Loop + } + r.altend = t.elems.Len() + r.or = t.elems.Len() + case tokSection: + t.parseSection(w) + case tokRepeated: + t.parseRepeated(w) + case tokAlternates: + if r.altstart >= 0 { + t.parseError("extra .alternates in .repeated section") + break Loop + } + if r.or >= 0 { + t.parseError(".alternates inside .or block in .repeated section") + break Loop + } + r.altstart = t.elems.Len() + default: + t.parseError("internal error: unknown repeated section item: %s", item) + break Loop + } + } + if r.altend < 0 { + r.altend = t.elems.Len() + } + r.end = t.elems.Len() + return r +} + +func (t *Template) parseSection(words []string) *sectionElement { + s := new(sectionElement) + t.elems.Push(s) + s.linenum = t.linenum + s.field = words[1] + // Scan section, collecting true and false (.or) blocks. 
+ s.start = t.elems.Len() + s.or = -1 +Loop: + for { + item := t.nextItem() + if len(item) == 0 { + t.parseError("missing .end for .section") + break + } + done, tok, w := t.parseSimple(item) + if done { + continue + } + switch tok { + case tokEnd: + break Loop + case tokOr: + if s.or >= 0 { + t.parseError("extra .or in .section") + break Loop + } + s.or = t.elems.Len() + case tokSection: + t.parseSection(w) + case tokRepeated: + t.parseRepeated(w) + case tokAlternates: + t.parseError(".alternates not in .repeated") + default: + t.parseError("internal error: unknown section item: %s", item) + } + } + s.end = t.elems.Len() + return s +} + +func (t *Template) parse() { + for { + item := t.nextItem() + if len(item) == 0 { + break + } + done, tok, w := t.parseSimple(item) + if done { + continue + } + switch tok { + case tokOr, tokEnd, tokAlternates: + t.parseError("unexpected %s", w[0]) + case tokSection: + t.parseSection(w) + case tokRepeated: + t.parseRepeated(w) + default: + t.parseError("internal error: bad directive in parse: %s", item) + } + } +} + +// -- Execution + +// Evaluate interfaces and pointers looking for a value that can look up the name, via a +// struct field, method, or map key, and return the result of the lookup. +func (t *Template) lookup(st *state, v reflect.Value, name string) reflect.Value { + for v != nil { + typ := v.Type() + if n := v.Type().NumMethod(); n > 0 { + for i := 0; i < n; i++ { + m := typ.Method(i) + mtyp := m.Type + if m.Name == name && mtyp.NumIn() == 1 && mtyp.NumOut() == 1 { + if !isExported(name) { + t.execError(st, t.linenum, "name not exported: %s in type %s", name, st.data.Type()) + } + return v.Method(i).Call(nil)[0] + } + } + } + switch av := v.(type) { + case *reflect.PtrValue: + v = av.Elem() + case *reflect.InterfaceValue: + v = av.Elem() + case *reflect.StructValue: + if !isExported(name) { + t.execError(st, t.linenum, "name not exported: %s in type %s", name, st.data.Type()) + } + return av.FieldByName(name) + case *reflect.MapValue: + if v := av.Elem(reflect.NewValue(name)); v != nil { + return v + } + return reflect.MakeZero(typ.(*reflect.MapType).Elem()) + default: + return nil + } + } + return v +} + +// indirectPtr returns the item numLevels levels of indirection below the value. +// It is forgiving: if the value is not a pointer, it returns it rather than giving +// an error. If the pointer is nil, it is returned as is. +func indirectPtr(v reflect.Value, numLevels int) reflect.Value { + for i := numLevels; v != nil && i > 0; i++ { + if p, ok := v.(*reflect.PtrValue); ok { + if p.IsNil() { + return v + } + v = p.Elem() + } else { + break + } + } + return v +} + +// Walk v through pointers and interfaces, extracting the elements within. +func indirect(v reflect.Value) reflect.Value { +loop: + for v != nil { + switch av := v.(type) { + case *reflect.PtrValue: + v = av.Elem() + case *reflect.InterfaceValue: + v = av.Elem() + default: + break loop + } + } + return v +} + +// If the data for this template is a struct, find the named variable. +// Names of the form a.b.c are walked down the data tree. +// The special name "@" (the "cursor") denotes the current data. +// The value coming in (st.data) might need indirecting to reach +// a struct while the return value is not indirected - that is, +// it represents the actual named field. Leading stars indicate +// levels of indirection to be applied to the value. 
+func (t *Template) findVar(st *state, s string) reflect.Value { + data := st.data + flattenedName := strings.TrimLeft(s, "*") + numStars := len(s) - len(flattenedName) + s = flattenedName + if s == "@" { + return indirectPtr(data, numStars) + } + for _, elem := range strings.Split(s, ".", -1) { + // Look up field; data must be a struct or map. + data = t.lookup(st, data, elem) + if data == nil { + return nil + } + } + return indirectPtr(data, numStars) +} + +// Is there no data to look at? +func empty(v reflect.Value) bool { + v = indirect(v) + if v == nil { + return true + } + switch v := v.(type) { + case *reflect.BoolValue: + return v.Get() == false + case *reflect.StringValue: + return v.Get() == "" + case *reflect.StructValue: + return false + case *reflect.MapValue: + return false + case *reflect.ArrayValue: + return v.Len() == 0 + case *reflect.SliceValue: + return v.Len() == 0 + } + return false +} + +// Look up a variable or method, up through the parent if necessary. +func (t *Template) varValue(name string, st *state) reflect.Value { + field := t.findVar(st, name) + if field == nil { + if st.parent == nil { + t.execError(st, t.linenum, "name not found: %s in type %s", name, st.data.Type()) + } + return t.varValue(name, st.parent) + } + return field +} + +func (t *Template) format(wr io.Writer, fmt string, val []interface{}, v *variableElement, st *state) { + fn := t.formatter(fmt) + if fn == nil { + t.execError(st, v.linenum, "missing formatter %s for variable %s", fmt, v.word[0]) + } + fn(wr, fmt, val...) +} + +// Evaluate a variable, looking up through the parent if necessary. +// If it has a formatter attached ({var|formatter}) run that too. +func (t *Template) writeVariable(v *variableElement, st *state) { + // Turn the words of the invocation into values. + val := make([]interface{}, len(v.word)) + for i, word := range v.word { + val[i] = t.varValue(word, st).Interface() + } + + for i, fmt := range v.fmts[:len(v.fmts)-1] { + b := &st.buf[i&1] + b.Reset() + t.format(b, fmt, val, v, st) + val = val[0:1] + val[0] = b.Bytes() + } + t.format(st.wr, v.fmts[len(v.fmts)-1], val, v, st) +} + +// Execute element i. Return next index to execute. +func (t *Template) executeElement(i int, st *state) int { + switch elem := t.elems.At(i).(type) { + case *textElement: + st.wr.Write(elem.text) + return i + 1 + case *literalElement: + st.wr.Write(elem.text) + return i + 1 + case *variableElement: + t.writeVariable(elem, st) + return i + 1 + case *sectionElement: + t.executeSection(elem, st) + return elem.end + case *repeatedElement: + t.executeRepeated(elem, st) + return elem.end + } + e := t.elems.At(i) + t.execError(st, 0, "internal error: bad directive in execute: %v %T\n", reflect.NewValue(e).Interface(), e) + return 0 +} + +// Execute the template. +func (t *Template) execute(start, end int, st *state) { + for i := start; i < end; { + i = t.executeElement(i, st) + } +} + +// Execute a .section +func (t *Template) executeSection(s *sectionElement, st *state) { + // Find driver data for this section. It must be in the current struct. + field := t.varValue(s.field, st) + if field == nil { + t.execError(st, s.linenum, ".section: cannot find field %s in %s", s.field, st.data.Type()) + } + st = st.clone(field) + start, end := s.start, s.or + if !empty(field) { + // Execute the normal block. + if end < 0 { + end = s.end + } + } else { + // Execute the .or block. If it's missing, do nothing. 
+ start, end = s.or, s.end + if start < 0 { + return + } + } + for i := start; i < end; { + i = t.executeElement(i, st) + } +} + +// Return the result of calling the Iter method on v, or nil. +func iter(v reflect.Value) *reflect.ChanValue { + for j := 0; j < v.Type().NumMethod(); j++ { + mth := v.Type().Method(j) + fv := v.Method(j) + ft := fv.Type().(*reflect.FuncType) + // TODO(rsc): NumIn() should return 0 here, because ft is from a curried FuncValue. + if mth.Name != "Iter" || ft.NumIn() != 1 || ft.NumOut() != 1 { + continue + } + ct, ok := ft.Out(0).(*reflect.ChanType) + if !ok || ct.Dir()&reflect.RecvDir == 0 { + continue + } + return fv.Call(nil)[0].(*reflect.ChanValue) + } + return nil +} + +// Execute a .repeated section +func (t *Template) executeRepeated(r *repeatedElement, st *state) { + // Find driver data for this section. It must be in the current struct. + field := t.varValue(r.field, st) + if field == nil { + t.execError(st, r.linenum, ".repeated: cannot find field %s in %s", r.field, st.data.Type()) + } + field = indirect(field) + + start, end := r.start, r.or + if end < 0 { + end = r.end + } + if r.altstart >= 0 { + end = r.altstart + } + first := true + + // Code common to all the loops. + loopBody := func(newst *state) { + // .alternates between elements + if !first && r.altstart >= 0 { + for i := r.altstart; i < r.altend; { + i = t.executeElement(i, newst) + } + } + first = false + for i := start; i < end; { + i = t.executeElement(i, newst) + } + } + + if array, ok := field.(reflect.ArrayOrSliceValue); ok { + for j := 0; j < array.Len(); j++ { + loopBody(st.clone(array.Elem(j))) + } + } else if m, ok := field.(*reflect.MapValue); ok { + for _, key := range m.Keys() { + loopBody(st.clone(m.Elem(key))) + } + } else if ch := iter(field); ch != nil { + for { + e, ok := ch.Recv() + if !ok { + break + } + loopBody(st.clone(e)) + } + } else { + t.execError(st, r.linenum, ".repeated: cannot repeat %s (type %s)", + r.field, field.Type()) + } + + if first { + // Empty. Execute the .or block, once. If it's missing, do nothing. + start, end := r.or, r.end + if start >= 0 { + newst := st.clone(field) + for i := start; i < end; { + i = t.executeElement(i, newst) + } + } + return + } +} + +// A valid delimiter must contain no white space and be non-empty. +func validDelim(d []byte) bool { + if len(d) == 0 { + return false + } + for _, c := range d { + if white(c) { + return false + } + } + return true +} + +// checkError is a deferred function to turn a panic with type *Error into a plain error return. +// Other panics are unexpected and so are re-enabled. +func checkError(error *os.Error) { + if v := recover(); v != nil { + if e, ok := v.(*Error); ok { + *error = e + } else { + // runtime errors should crash + panic(v) + } + } +} + +// -- Public interface + +// Parse initializes a Template by parsing its definition. The string +// s contains the template text. If any errors occur, Parse returns +// the error. +func (t *Template) Parse(s string) (err os.Error) { + if t.elems == nil { + return &Error{1, "template not allocated with New"} + } + if !validDelim(t.ldelim) || !validDelim(t.rdelim) { + return &Error{1, fmt.Sprintf("bad delimiter strings %q %q", t.ldelim, t.rdelim)} + } + defer checkError(&err) + t.buf = []byte(s) + t.p = 0 + t.linenum = 1 + t.parse() + return nil +} + +// ParseFile is like Parse but reads the template definition from the +// named file. 
+func (t *Template) ParseFile(filename string) (err os.Error) { + b, err := ioutil.ReadFile(filename) + if err != nil { + return err + } + return t.Parse(string(b)) +} + +// Execute applies a parsed template to the specified data object, +// generating output to wr. +func (t *Template) Execute(wr io.Writer, data interface{}) (err os.Error) { + // Extract the driver data. + val := reflect.NewValue(data) + defer checkError(&err) + t.p = 0 + t.execute(0, t.elems.Len(), &state{parent: nil, data: val, wr: wr}) + return nil +} + +// SetDelims sets the left and right delimiters for operations in the +// template. They are validated during parsing. They could be +// validated here but it's better to keep the routine simple. The +// delimiters are very rarely invalid and Parse has the necessary +// error-handling interface already. +func (t *Template) SetDelims(left, right string) { + t.ldelim = []byte(left) + t.rdelim = []byte(right) +} + +// Parse creates a Template with default parameters (such as {} for +// metacharacters). The string s contains the template text while +// the formatter map fmap, which may be nil, defines auxiliary functions +// for formatting variables. The template is returned. If any errors +// occur, err will be non-nil. +func Parse(s string, fmap FormatterMap) (t *Template, err os.Error) { + t = New(fmap) + err = t.Parse(s) + if err != nil { + t = nil + } + return +} + +// ParseFile is a wrapper function that creates a Template with default +// parameters (such as {} for metacharacters). The filename identifies +// a file containing the template text, while the formatter map fmap, which +// may be nil, defines auxiliary functions for formatting variables. +// The template is returned. If any errors occur, err will be non-nil. +func ParseFile(filename string, fmap FormatterMap) (t *Template, err os.Error) { + b, err := ioutil.ReadFile(filename) + if err != nil { + return nil, err + } + return Parse(string(b), fmap) +} + +// MustParse is like Parse but panics if the template cannot be parsed. +func MustParse(s string, fmap FormatterMap) *Template { + t, err := Parse(s, fmap) + if err != nil { + panic("template.MustParse error: " + err.String()) + } + return t +} + +// MustParseFile is like ParseFile but panics if the file cannot be read +// or the template cannot be parsed. +func MustParseFile(filename string, fmap FormatterMap) *Template { + b, err := ioutil.ReadFile(filename) + if err != nil { + panic("template.MustParseFile error: " + err.String()) + } + return MustParse(string(b), fmap) +} diff --git a/src/cmd/gofix/testdata/reflect.template.go.out b/src/cmd/gofix/testdata/reflect.template.go.out new file mode 100644 index 000000000..28872dbee --- /dev/null +++ b/src/cmd/gofix/testdata/reflect.template.go.out @@ -0,0 +1,1044 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +/* + Data-driven templates for generating textual output such as + HTML. + + Templates are executed by applying them to a data structure. + Annotations in the template refer to elements of the data + structure (typically a field of a struct or a key in a map) + to control execution and derive values to be displayed. + The template walks the structure as it executes and the + "cursor" @ represents the value at the current location + in the structure. + + Data items may be values or pointers; the interface hides the + indirection. 
+ + In the following, 'field' is one of several things, according to the data. + + - The name of a field of a struct (result = data.field), + - The value stored in a map under that key (result = data[field]), or + - The result of invoking a niladic single-valued method with that name + (result = data.field()) + + Major constructs ({} are the default delimiters for template actions; + [] are the notation in this comment for optional elements): + + {# comment } + + A one-line comment. + + {.section field} XXX [ {.or} YYY ] {.end} + + Set @ to the value of the field. It may be an explicit @ + to stay at the same point in the data. If the field is nil + or empty, execute YYY; otherwise execute XXX. + + {.repeated section field} XXX [ {.alternates with} ZZZ ] [ {.or} YYY ] {.end} + + Like .section, but field must be an array or slice. XXX + is executed for each element. If the array is nil or empty, + YYY is executed instead. If the {.alternates with} marker + is present, ZZZ is executed between iterations of XXX. + + {field} + {field1 field2 ...} + {field|formatter} + {field1 field2...|formatter} + {field|formatter1|formatter2} + + Insert the value of the fields into the output. Each field is + first looked for in the cursor, as in .section and .repeated. + If it is not found, the search continues in outer sections + until the top level is reached. + + If the field value is a pointer, leading asterisks indicate + that the value to be inserted should be evaluated through the + pointer. For example, if x.p is of type *int, {x.p} will + insert the value of the pointer but {*x.p} will insert the + value of the underlying integer. If the value is nil or not a + pointer, asterisks have no effect. + + If a formatter is specified, it must be named in the formatter + map passed to the template set up routines or in the default + set ("html","str","") and is used to process the data for + output. The formatter function has signature + func(wr io.Writer, formatter string, data ...interface{}) + where wr is the destination for output, data holds the field + values at the instantiation, and formatter is its name at + the invocation site. The default formatter just concatenates + the string representations of the fields. + + Multiple formatters separated by the pipeline character | are + executed sequentially, with each formatter receiving the bytes + emitted by the one to its left. + + The delimiter strings get their default value, "{" and "}", from + JSON-template. They may be set to any non-empty, space-free + string using the SetDelims method. Their value can be printed + in the output using {.meta-left} and {.meta-right}. +*/ +package template + +import ( + "bytes" + "container/vector" + "fmt" + "io" + "io/ioutil" + "os" + "reflect" + "strings" + "unicode" + "utf8" +) + +// Errors returned during parsing and execution. Users may extract the information and reformat +// if they desire. +type Error struct { + Line int + Msg string +} + +func (e *Error) String() string { return fmt.Sprintf("line %d: %s", e.Line, e.Msg) } + +// Most of the literals are aces. +var lbrace = []byte{'{'} +var rbrace = []byte{'}'} +var space = []byte{' '} +var tab = []byte{'\t'} + +// The various types of "tokens", which are plain text or (usually) brace-delimited descriptors +const ( + tokAlternates = iota + tokComment + tokEnd + tokLiteral + tokOr + tokRepeated + tokSection + tokText + tokVariable +) + +// FormatterMap is the type describing the mapping from formatter +// names to the functions that implement them. 
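The package comment reproduced above lays out the {.section} / {.repeated section} / formatter syntax. As a rough usage sketch against the API defined in this very file (New, MustParse, Execute, FormatterMap, the "html"/"str" builtins): the Title and Items fields below are hypothetical, the import path "template" is the one used by this 2011-era tree, and the code will not build with today's text/template package.

package main

import (
	"os"
	"template" // the pre-Go 1 package shown in the diff above
)

type page struct {
	Title string
	Items []string
}

const text = `{Title|html}
{.repeated section Items}
 - {@}
{.or}
 (no items)
{.end}
`

func main() {
	// nil FormatterMap: fall back to the builtin formatters ("", "str", "html").
	t := template.MustParse(text, nil)
	if err := t.Execute(os.Stdout, page{"Demo", []string{"a", "b"}}); err != nil {
		panic(err)
	}
	// Prints:
	// Demo
	//  - a
	//  - b
}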
+type FormatterMap map[string]func(io.Writer, string, ...interface{}) + +// Built-in formatters. +var builtins = FormatterMap{ + "html": HTMLFormatter, + "str": StringFormatter, + "": StringFormatter, +} + +// The parsed state of a template is a vector of xxxElement structs. +// Sections have line numbers so errors can be reported better during execution. + +// Plain text. +type textElement struct { + text []byte +} + +// A literal such as .meta-left or .meta-right +type literalElement struct { + text []byte +} + +// A variable invocation to be evaluated +type variableElement struct { + linenum int + word []string // The fields in the invocation. + fmts []string // Names of formatters to apply. len(fmts) > 0 +} + +// A .section block, possibly with a .or +type sectionElement struct { + linenum int // of .section itself + field string // cursor field for this block + start int // first element + or int // first element of .or block + end int // one beyond last element +} + +// A .repeated block, possibly with a .or and a .alternates +type repeatedElement struct { + sectionElement // It has the same structure... + altstart int // ... except for alternates + altend int +} + +// Template is the type that represents a template definition. +// It is unchanged after parsing. +type Template struct { + fmap FormatterMap // formatters for variables + // Used during parsing: + ldelim, rdelim []byte // delimiters; default {} + buf []byte // input text to process + p int // position in buf + linenum int // position in input + // Parsed results: + elems *vector.Vector +} + +// Internal state for executing a Template. As we evaluate the struct, +// the data item descends into the fields associated with sections, etc. +// Parent is used to walk upwards to find variables higher in the tree. +type state struct { + parent *state // parent in hierarchy + data reflect.Value // the driver data for this section etc. + wr io.Writer // where to send output + buf [2]bytes.Buffer // alternating buffers used when chaining formatters +} + +func (parent *state) clone(data reflect.Value) *state { + return &state{parent: parent, data: data, wr: parent.wr} +} + +// New creates a new template with the specified formatter map (which +// may be nil) to define auxiliary functions for formatting variables. +func New(fmap FormatterMap) *Template { + t := new(Template) + t.fmap = fmap + t.ldelim = lbrace + t.rdelim = rbrace + t.elems = new(vector.Vector) + return t +} + +// Report error and stop executing. The line number must be provided explicitly. +func (t *Template) execError(st *state, line int, err string, args ...interface{}) { + panic(&Error{line, fmt.Sprintf(err, args...)}) +} + +// Report error, panic to terminate parsing. +// The line number comes from the template state. +func (t *Template) parseError(err string, args ...interface{}) { + panic(&Error{t.linenum, fmt.Sprintf(err, args...)}) +} + +// Is this an exported - upper case - name? +func isExported(name string) bool { + rune, _ := utf8.DecodeRuneInString(name) + return unicode.IsUpper(rune) +} + +// -- Lexical analysis + +// Is c a white space character? +func white(c uint8) bool { return c == ' ' || c == '\t' || c == '\r' || c == '\n' } + +// Safely, does s[n:n+len(t)] == t? +func equal(s []byte, n int, t []byte) bool { + b := s[n:] + if len(t) > len(b) { // not enough space left for a match. + return false + } + for i, c := range t { + if c != b[i] { + return false + } + } + return true +} + +// nextItem returns the next item from the input buffer. 
If the returned +// item is empty, we are at EOF. The item will be either a +// delimited string or a non-empty string between delimited +// strings. Tokens stop at (but include, if plain text) a newline. +// Action tokens on a line by themselves drop any space on +// either side, up to and including the newline. +func (t *Template) nextItem() []byte { + startOfLine := t.p == 0 || t.buf[t.p-1] == '\n' + start := t.p + var i int + newline := func() { + t.linenum++ + i++ + } + // Leading white space up to but not including newline + for i = start; i < len(t.buf); i++ { + if t.buf[i] == '\n' || !white(t.buf[i]) { + break + } + } + leadingSpace := i > start + // What's left is nothing, newline, delimited string, or plain text + switch { + case i == len(t.buf): + // EOF; nothing to do + case t.buf[i] == '\n': + newline() + case equal(t.buf, i, t.ldelim): + left := i // Start of left delimiter. + right := -1 // Will be (immediately after) right delimiter. + haveText := false // Delimiters contain text. + i += len(t.ldelim) + // Find the end of the action. + for ; i < len(t.buf); i++ { + if t.buf[i] == '\n' { + break + } + if equal(t.buf, i, t.rdelim) { + i += len(t.rdelim) + right = i + break + } + haveText = true + } + if right < 0 { + t.parseError("unmatched opening delimiter") + return nil + } + // Is this a special action (starts with '.' or '#') and the only thing on the line? + if startOfLine && haveText { + firstChar := t.buf[left+len(t.ldelim)] + if firstChar == '.' || firstChar == '#' { + // It's special and the first thing on the line. Is it the last? + for j := right; j < len(t.buf) && white(t.buf[j]); j++ { + if t.buf[j] == '\n' { + // Yes it is. Drop the surrounding space and return the {.foo} + t.linenum++ + t.p = j + 1 + return t.buf[left:right] + } + } + } + } + // No it's not. If there's leading space, return that. + if leadingSpace { + // not trimming space: return leading white space if there is some. + t.p = left + return t.buf[start:left] + } + // Return the word, leave the trailing space. + start = left + break + default: + for ; i < len(t.buf); i++ { + if t.buf[i] == '\n' { + newline() + break + } + if equal(t.buf, i, t.ldelim) { + break + } + } + } + item := t.buf[start:i] + t.p = i + return item +} + +// Turn a byte array into a white-space-split array of strings. +func words(buf []byte) []string { + s := make([]string, 0, 5) + p := 0 // position in buf + // one word per loop + for i := 0; ; i++ { + // skip white space + for ; p < len(buf) && white(buf[p]); p++ { + } + // grab word + start := p + for ; p < len(buf) && !white(buf[p]); p++ { + } + if start == p { // no text left + break + } + s = append(s, string(buf[start:p])) + } + return s +} + +// Analyze an item and return its token type and, if it's an action item, an array of +// its constituent words. 
+func (t *Template) analyze(item []byte) (tok int, w []string) { + // item is known to be non-empty + if !equal(item, 0, t.ldelim) { // doesn't start with left delimiter + tok = tokText + return + } + if !equal(item, len(item)-len(t.rdelim), t.rdelim) { // doesn't end with right delimiter + t.parseError("internal error: unmatched opening delimiter") // lexing should prevent this + return + } + if len(item) <= len(t.ldelim)+len(t.rdelim) { // no contents + t.parseError("empty directive") + return + } + // Comment + if item[len(t.ldelim)] == '#' { + tok = tokComment + return + } + // Split into words + w = words(item[len(t.ldelim) : len(item)-len(t.rdelim)]) // drop final delimiter + if len(w) == 0 { + t.parseError("empty directive") + return + } + if len(w) > 0 && w[0][0] != '.' { + tok = tokVariable + return + } + switch w[0] { + case ".meta-left", ".meta-right", ".space", ".tab": + tok = tokLiteral + return + case ".or": + tok = tokOr + return + case ".end": + tok = tokEnd + return + case ".section": + if len(w) != 2 { + t.parseError("incorrect fields for .section: %s", item) + return + } + tok = tokSection + return + case ".repeated": + if len(w) != 3 || w[1] != "section" { + t.parseError("incorrect fields for .repeated: %s", item) + return + } + tok = tokRepeated + return + case ".alternates": + if len(w) != 2 || w[1] != "with" { + t.parseError("incorrect fields for .alternates: %s", item) + return + } + tok = tokAlternates + return + } + t.parseError("bad directive: %s", item) + return +} + +// formatter returns the Formatter with the given name in the Template, or nil if none exists. +func (t *Template) formatter(name string) func(io.Writer, string, ...interface{}) { + if t.fmap != nil { + if fn := t.fmap[name]; fn != nil { + return fn + } + } + return builtins[name] +} + +// -- Parsing + +// Allocate a new variable-evaluation element. +func (t *Template) newVariable(words []string) *variableElement { + // After the final space-separated argument, formatters may be specified separated + // by pipe symbols, for example: {a b c|d|e} + + // Until we learn otherwise, formatters contains a single name: "", the default formatter. + formatters := []string{""} + lastWord := words[len(words)-1] + bar := strings.IndexRune(lastWord, '|') + if bar >= 0 { + words[len(words)-1] = lastWord[0:bar] + formatters = strings.Split(lastWord[bar+1:], "|", -1) + } + + // We could remember the function address here and avoid the lookup later, + // but it's more dynamic to let the user change the map contents underfoot. + // We do require the name to be present, though. + + // Is it in user-supplied map? + for _, f := range formatters { + if t.formatter(f) == nil { + t.parseError("unknown formatter: %q", f) + } + } + return &variableElement{t.linenum, words, formatters} +} + +// Grab the next item. If it's simple, just append it to the template. +// Otherwise return its details. 
+func (t *Template) parseSimple(item []byte) (done bool, tok int, w []string) { + tok, w = t.analyze(item) + done = true // assume for simplicity + switch tok { + case tokComment: + return + case tokText: + t.elems.Push(&textElement{item}) + return + case tokLiteral: + switch w[0] { + case ".meta-left": + t.elems.Push(&literalElement{t.ldelim}) + case ".meta-right": + t.elems.Push(&literalElement{t.rdelim}) + case ".space": + t.elems.Push(&literalElement{space}) + case ".tab": + t.elems.Push(&literalElement{tab}) + default: + t.parseError("internal error: unknown literal: %s", w[0]) + } + return + case tokVariable: + t.elems.Push(t.newVariable(w)) + return + } + return false, tok, w +} + +// parseRepeated and parseSection are mutually recursive + +func (t *Template) parseRepeated(words []string) *repeatedElement { + r := new(repeatedElement) + t.elems.Push(r) + r.linenum = t.linenum + r.field = words[2] + // Scan section, collecting true and false (.or) blocks. + r.start = t.elems.Len() + r.or = -1 + r.altstart = -1 + r.altend = -1 +Loop: + for { + item := t.nextItem() + if len(item) == 0 { + t.parseError("missing .end for .repeated section") + break + } + done, tok, w := t.parseSimple(item) + if done { + continue + } + switch tok { + case tokEnd: + break Loop + case tokOr: + if r.or >= 0 { + t.parseError("extra .or in .repeated section") + break Loop + } + r.altend = t.elems.Len() + r.or = t.elems.Len() + case tokSection: + t.parseSection(w) + case tokRepeated: + t.parseRepeated(w) + case tokAlternates: + if r.altstart >= 0 { + t.parseError("extra .alternates in .repeated section") + break Loop + } + if r.or >= 0 { + t.parseError(".alternates inside .or block in .repeated section") + break Loop + } + r.altstart = t.elems.Len() + default: + t.parseError("internal error: unknown repeated section item: %s", item) + break Loop + } + } + if r.altend < 0 { + r.altend = t.elems.Len() + } + r.end = t.elems.Len() + return r +} + +func (t *Template) parseSection(words []string) *sectionElement { + s := new(sectionElement) + t.elems.Push(s) + s.linenum = t.linenum + s.field = words[1] + // Scan section, collecting true and false (.or) blocks. + s.start = t.elems.Len() + s.or = -1 +Loop: + for { + item := t.nextItem() + if len(item) == 0 { + t.parseError("missing .end for .section") + break + } + done, tok, w := t.parseSimple(item) + if done { + continue + } + switch tok { + case tokEnd: + break Loop + case tokOr: + if s.or >= 0 { + t.parseError("extra .or in .section") + break Loop + } + s.or = t.elems.Len() + case tokSection: + t.parseSection(w) + case tokRepeated: + t.parseRepeated(w) + case tokAlternates: + t.parseError(".alternates not in .repeated") + default: + t.parseError("internal error: unknown section item: %s", item) + } + } + s.end = t.elems.Len() + return s +} + +func (t *Template) parse() { + for { + item := t.nextItem() + if len(item) == 0 { + break + } + done, tok, w := t.parseSimple(item) + if done { + continue + } + switch tok { + case tokOr, tokEnd, tokAlternates: + t.parseError("unexpected %s", w[0]) + case tokSection: + t.parseSection(w) + case tokRepeated: + t.parseRepeated(w) + default: + t.parseError("internal error: bad directive in parse: %s", item) + } + } +} + +// -- Execution + +// Evaluate interfaces and pointers looking for a value that can look up the name, via a +// struct field, method, or map key, and return the result of the lookup. 
+func (t *Template) lookup(st *state, v reflect.Value, name string) reflect.Value { + for v.IsValid() { + typ := v.Type() + if n := v.Type().NumMethod(); n > 0 { + for i := 0; i < n; i++ { + m := typ.Method(i) + mtyp := m.Type + if m.Name == name && mtyp.NumIn() == 1 && mtyp.NumOut() == 1 { + if !isExported(name) { + t.execError(st, t.linenum, "name not exported: %s in type %s", name, st.data.Type()) + } + return v.Method(i).Call(nil)[0] + } + } + } + switch av := v; av.Kind() { + case reflect.Ptr: + v = av.Elem() + case reflect.Interface: + v = av.Elem() + case reflect.Struct: + if !isExported(name) { + t.execError(st, t.linenum, "name not exported: %s in type %s", name, st.data.Type()) + } + return av.FieldByName(name) + case reflect.Map: + if v := av.MapIndex(reflect.NewValue(name)); v.IsValid() { + return v + } + return reflect.Zero(typ.Elem()) + default: + return reflect.Value{} + } + } + return v +} + +// indirectPtr returns the item numLevels levels of indirection below the value. +// It is forgiving: if the value is not a pointer, it returns it rather than giving +// an error. If the pointer is nil, it is returned as is. +func indirectPtr(v reflect.Value, numLevels int) reflect.Value { + for i := numLevels; v.IsValid() && i > 0; i++ { + if p := v; p.Kind() == reflect.Ptr { + if p.IsNil() { + return v + } + v = p.Elem() + } else { + break + } + } + return v +} + +// Walk v through pointers and interfaces, extracting the elements within. +func indirect(v reflect.Value) reflect.Value { +loop: + for v.IsValid() { + switch av := v; av.Kind() { + case reflect.Ptr: + v = av.Elem() + case reflect.Interface: + v = av.Elem() + default: + break loop + } + } + return v +} + +// If the data for this template is a struct, find the named variable. +// Names of the form a.b.c are walked down the data tree. +// The special name "@" (the "cursor") denotes the current data. +// The value coming in (st.data) might need indirecting to reach +// a struct while the return value is not indirected - that is, +// it represents the actual named field. Leading stars indicate +// levels of indirection to be applied to the value. +func (t *Template) findVar(st *state, s string) reflect.Value { + data := st.data + flattenedName := strings.TrimLeft(s, "*") + numStars := len(s) - len(flattenedName) + s = flattenedName + if s == "@" { + return indirectPtr(data, numStars) + } + for _, elem := range strings.Split(s, ".", -1) { + // Look up field; data must be a struct or map. + data = t.lookup(st, data, elem) + if !data.IsValid() { + return reflect.Value{} + } + } + return indirectPtr(data, numStars) +} + +// Is there no data to look at? +func empty(v reflect.Value) bool { + v = indirect(v) + if !v.IsValid() { + return true + } + switch v.Kind() { + case reflect.Bool: + return v.Bool() == false + case reflect.String: + return v.String() == "" + case reflect.Struct: + return false + case reflect.Map: + return false + case reflect.Array: + return v.Len() == 0 + case reflect.Slice: + return v.Len() == 0 + } + return false +} + +// Look up a variable or method, up through the parent if necessary. 
+func (t *Template) varValue(name string, st *state) reflect.Value { + field := t.findVar(st, name) + if !field.IsValid() { + if st.parent == nil { + t.execError(st, t.linenum, "name not found: %s in type %s", name, st.data.Type()) + } + return t.varValue(name, st.parent) + } + return field +} + +func (t *Template) format(wr io.Writer, fmt string, val []interface{}, v *variableElement, st *state) { + fn := t.formatter(fmt) + if fn == nil { + t.execError(st, v.linenum, "missing formatter %s for variable %s", fmt, v.word[0]) + } + fn(wr, fmt, val...) +} + +// Evaluate a variable, looking up through the parent if necessary. +// If it has a formatter attached ({var|formatter}) run that too. +func (t *Template) writeVariable(v *variableElement, st *state) { + // Turn the words of the invocation into values. + val := make([]interface{}, len(v.word)) + for i, word := range v.word { + val[i] = t.varValue(word, st).Interface() + } + + for i, fmt := range v.fmts[:len(v.fmts)-1] { + b := &st.buf[i&1] + b.Reset() + t.format(b, fmt, val, v, st) + val = val[0:1] + val[0] = b.Bytes() + } + t.format(st.wr, v.fmts[len(v.fmts)-1], val, v, st) +} + +// Execute element i. Return next index to execute. +func (t *Template) executeElement(i int, st *state) int { + switch elem := t.elems.At(i).(type) { + case *textElement: + st.wr.Write(elem.text) + return i + 1 + case *literalElement: + st.wr.Write(elem.text) + return i + 1 + case *variableElement: + t.writeVariable(elem, st) + return i + 1 + case *sectionElement: + t.executeSection(elem, st) + return elem.end + case *repeatedElement: + t.executeRepeated(elem, st) + return elem.end + } + e := t.elems.At(i) + t.execError(st, 0, "internal error: bad directive in execute: %v %T\n", reflect.NewValue(e).Interface(), e) + return 0 +} + +// Execute the template. +func (t *Template) execute(start, end int, st *state) { + for i := start; i < end; { + i = t.executeElement(i, st) + } +} + +// Execute a .section +func (t *Template) executeSection(s *sectionElement, st *state) { + // Find driver data for this section. It must be in the current struct. + field := t.varValue(s.field, st) + if !field.IsValid() { + t.execError(st, s.linenum, ".section: cannot find field %s in %s", s.field, st.data.Type()) + } + st = st.clone(field) + start, end := s.start, s.or + if !empty(field) { + // Execute the normal block. + if end < 0 { + end = s.end + } + } else { + // Execute the .or block. If it's missing, do nothing. + start, end = s.or, s.end + if start < 0 { + return + } + } + for i := start; i < end; { + i = t.executeElement(i, st) + } +} + +// Return the result of calling the Iter method on v, or nil. +func iter(v reflect.Value) reflect.Value { + for j := 0; j < v.Type().NumMethod(); j++ { + mth := v.Type().Method(j) + fv := v.Method(j) + ft := fv.Type() + // TODO(rsc): NumIn() should return 0 here, because ft is from a curried FuncValue. + if mth.Name != "Iter" || ft.NumIn() != 1 || ft.NumOut() != 1 { + continue + } + ct := ft.Out(0) + if ct.Kind() != reflect.Chan || + ct.ChanDir()&reflect.RecvDir == 0 { + continue + } + return fv.Call(nil)[0] + } + return reflect.Value{} +} + +// Execute a .repeated section +func (t *Template) executeRepeated(r *repeatedElement, st *state) { + // Find driver data for this section. It must be in the current struct. 
+ field := t.varValue(r.field, st) + if !field.IsValid() { + t.execError(st, r.linenum, ".repeated: cannot find field %s in %s", r.field, st.data.Type()) + } + field = indirect(field) + + start, end := r.start, r.or + if end < 0 { + end = r.end + } + if r.altstart >= 0 { + end = r.altstart + } + first := true + + // Code common to all the loops. + loopBody := func(newst *state) { + // .alternates between elements + if !first && r.altstart >= 0 { + for i := r.altstart; i < r.altend; { + i = t.executeElement(i, newst) + } + } + first = false + for i := start; i < end; { + i = t.executeElement(i, newst) + } + } + + if array := field; array.Kind() == reflect.Array || array.Kind() == reflect.Slice { + for j := 0; j < array.Len(); j++ { + loopBody(st.clone(array.Index(j))) + } + } else if m := field; m.Kind() == reflect.Map { + for _, key := range m.MapKeys() { + loopBody(st.clone(m.MapIndex(key))) + } + } else if ch := iter(field); ch.IsValid() { + for { + e, ok := ch.Recv() + if !ok { + break + } + loopBody(st.clone(e)) + } + } else { + t.execError(st, r.linenum, ".repeated: cannot repeat %s (type %s)", + r.field, field.Type()) + } + + if first { + // Empty. Execute the .or block, once. If it's missing, do nothing. + start, end := r.or, r.end + if start >= 0 { + newst := st.clone(field) + for i := start; i < end; { + i = t.executeElement(i, newst) + } + } + return + } +} + +// A valid delimiter must contain no white space and be non-empty. +func validDelim(d []byte) bool { + if len(d) == 0 { + return false + } + for _, c := range d { + if white(c) { + return false + } + } + return true +} + +// checkError is a deferred function to turn a panic with type *Error into a plain error return. +// Other panics are unexpected and so are re-enabled. +func checkError(error *os.Error) { + if v := recover(); v != nil { + if e, ok := v.(*Error); ok { + *error = e + } else { + // runtime errors should crash + panic(v) + } + } +} + +// -- Public interface + +// Parse initializes a Template by parsing its definition. The string +// s contains the template text. If any errors occur, Parse returns +// the error. +func (t *Template) Parse(s string) (err os.Error) { + if t.elems == nil { + return &Error{1, "template not allocated with New"} + } + if !validDelim(t.ldelim) || !validDelim(t.rdelim) { + return &Error{1, fmt.Sprintf("bad delimiter strings %q %q", t.ldelim, t.rdelim)} + } + defer checkError(&err) + t.buf = []byte(s) + t.p = 0 + t.linenum = 1 + t.parse() + return nil +} + +// ParseFile is like Parse but reads the template definition from the +// named file. +func (t *Template) ParseFile(filename string) (err os.Error) { + b, err := ioutil.ReadFile(filename) + if err != nil { + return err + } + return t.Parse(string(b)) +} + +// Execute applies a parsed template to the specified data object, +// generating output to wr. +func (t *Template) Execute(wr io.Writer, data interface{}) (err os.Error) { + // Extract the driver data. + val := reflect.NewValue(data) + defer checkError(&err) + t.p = 0 + t.execute(0, t.elems.Len(), &state{parent: nil, data: val, wr: wr}) + return nil +} + +// SetDelims sets the left and right delimiters for operations in the +// template. They are validated during parsing. They could be +// validated here but it's better to keep the routine simple. The +// delimiters are very rarely invalid and Parse has the necessary +// error-handling interface already. 
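Both copies of the file use the error-handling convention visible just above: parse and execution failures panic with an *Error, and the deferred checkError converts that panic back into an ordinary error return at the public entry points (Parse, Execute), while re-raising anything unexpected. A self-contained sketch of the same pattern, using the current error interface instead of os.Error and hypothetical names:

package main

import "fmt"

// parseError plays the role of the *Error type in the template code above.
type parseError struct {
	line int
	msg  string
}

func (e *parseError) Error() string { return fmt.Sprintf("line %d: %s", e.line, e.msg) }

// checkError mirrors the deferred handler above: an expected panic carrying
// *parseError becomes an ordinary error return; anything else still crashes.
func checkError(err *error) {
	if v := recover(); v != nil {
		if e, ok := v.(*parseError); ok {
			*err = e
			return
		}
		panic(v) // runtime errors and other panics propagate
	}
}

func parse(bad bool) (err error) {
	defer checkError(&err)
	if bad {
		panic(&parseError{1, "empty directive"})
	}
	return nil
}

func main() {
	fmt.Println(parse(true))  // line 1: empty directive
	fmt.Println(parse(false)) // <nil>
}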
+func (t *Template) SetDelims(left, right string) { + t.ldelim = []byte(left) + t.rdelim = []byte(right) +} + +// Parse creates a Template with default parameters (such as {} for +// metacharacters). The string s contains the template text while +// the formatter map fmap, which may be nil, defines auxiliary functions +// for formatting variables. The template is returned. If any errors +// occur, err will be non-nil. +func Parse(s string, fmap FormatterMap) (t *Template, err os.Error) { + t = New(fmap) + err = t.Parse(s) + if err != nil { + t = nil + } + return +} + +// ParseFile is a wrapper function that creates a Template with default +// parameters (such as {} for metacharacters). The filename identifies +// a file containing the template text, while the formatter map fmap, which +// may be nil, defines auxiliary functions for formatting variables. +// The template is returned. If any errors occur, err will be non-nil. +func ParseFile(filename string, fmap FormatterMap) (t *Template, err os.Error) { + b, err := ioutil.ReadFile(filename) + if err != nil { + return nil, err + } + return Parse(string(b), fmap) +} + +// MustParse is like Parse but panics if the template cannot be parsed. +func MustParse(s string, fmap FormatterMap) *Template { + t, err := Parse(s, fmap) + if err != nil { + panic("template.MustParse error: " + err.String()) + } + return t +} + +// MustParseFile is like ParseFile but panics if the file cannot be read +// or the template cannot be parsed. +func MustParseFile(filename string, fmap FormatterMap) *Template { + b, err := ioutil.ReadFile(filename) + if err != nil { + panic("template.MustParseFile error: " + err.String()) + } + return MustParse(string(b), fmap) +} diff --git a/src/cmd/gofix/testdata/reflect.type.go.in b/src/cmd/gofix/testdata/reflect.type.go.in new file mode 100644 index 000000000..305d41980 --- /dev/null +++ b/src/cmd/gofix/testdata/reflect.type.go.in @@ -0,0 +1,789 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package gob + +import ( + "fmt" + "os" + "reflect" + "sync" + "unicode" + "utf8" +) + +// userTypeInfo stores the information associated with a type the user has handed +// to the package. It's computed once and stored in a map keyed by reflection +// type. +type userTypeInfo struct { + user reflect.Type // the type the user handed us + base reflect.Type // the base type after all indirections + indir int // number of indirections to reach the base type + isGobEncoder bool // does the type implement GobEncoder? + isGobDecoder bool // does the type implement GobDecoder? + encIndir int8 // number of indirections to reach the receiver type; may be negative + decIndir int8 // number of indirections to reach the receiver type; may be negative +} + +var ( + // Protected by an RWMutex because we read it a lot and write + // it only when we see a new type, typically when compiling. + userTypeLock sync.RWMutex + userTypeCache = make(map[reflect.Type]*userTypeInfo) +) + +// validType returns, and saves, the information associated with user-provided type rt. +// If the user type is not valid, err will be non-nil. To be used when the error handler +// is not set up. +func validUserType(rt reflect.Type) (ut *userTypeInfo, err os.Error) { + userTypeLock.RLock() + ut = userTypeCache[rt] + userTypeLock.RUnlock() + if ut != nil { + return + } + // Now set the value under the write lock. 
+ userTypeLock.Lock() + defer userTypeLock.Unlock() + if ut = userTypeCache[rt]; ut != nil { + // Lost the race; not a problem. + return + } + ut = new(userTypeInfo) + ut.base = rt + ut.user = rt + // A type that is just a cycle of pointers (such as type T *T) cannot + // be represented in gobs, which need some concrete data. We use a + // cycle detection algorithm from Knuth, Vol 2, Section 3.1, Ex 6, + // pp 539-540. As we step through indirections, run another type at + // half speed. If they meet up, there's a cycle. + slowpoke := ut.base // walks half as fast as ut.base + for { + pt, ok := ut.base.(*reflect.PtrType) + if !ok { + break + } + ut.base = pt.Elem() + if ut.base == slowpoke { // ut.base lapped slowpoke + // recursive pointer type. + return nil, os.ErrorString("can't represent recursive pointer type " + ut.base.String()) + } + if ut.indir%2 == 0 { + slowpoke = slowpoke.(*reflect.PtrType).Elem() + } + ut.indir++ + } + ut.isGobEncoder, ut.encIndir = implementsInterface(ut.user, gobEncoderCheck) + ut.isGobDecoder, ut.decIndir = implementsInterface(ut.user, gobDecoderCheck) + userTypeCache[rt] = ut + return +} + +const ( + gobEncodeMethodName = "GobEncode" + gobDecodeMethodName = "GobDecode" +) + +// implements returns whether the type implements the interface, as encoded +// in the check function. +func implements(typ reflect.Type, check func(typ reflect.Type) bool) bool { + if typ.NumMethod() == 0 { // avoid allocations etc. unless there's some chance + return false + } + return check(typ) +} + +// gobEncoderCheck makes the type assertion a boolean function. +func gobEncoderCheck(typ reflect.Type) bool { + _, ok := reflect.MakeZero(typ).Interface().(GobEncoder) + return ok +} + +// gobDecoderCheck makes the type assertion a boolean function. +func gobDecoderCheck(typ reflect.Type) bool { + _, ok := reflect.MakeZero(typ).Interface().(GobDecoder) + return ok +} + +// implementsInterface reports whether the type implements the +// interface. (The actual check is done through the provided function.) +// It also returns the number of indirections required to get to the +// implementation. +func implementsInterface(typ reflect.Type, check func(typ reflect.Type) bool) (success bool, indir int8) { + if typ == nil { + return + } + rt := typ + // The type might be a pointer and we need to keep + // dereferencing to the base type until we find an implementation. + for { + if implements(rt, check) { + return true, indir + } + if p, ok := rt.(*reflect.PtrType); ok { + indir++ + if indir > 100 { // insane number of indirections + return false, 0 + } + rt = p.Elem() + continue + } + break + } + // No luck yet, but if this is a base type (non-pointer), the pointer might satisfy. + if _, ok := typ.(*reflect.PtrType); !ok { + // Not a pointer, but does the pointer work? + if implements(reflect.PtrTo(typ), check) { + return true, -1 + } + } + return false, 0 +} + +// userType returns, and saves, the information associated with user-provided type rt. +// If the user type is not valid, it calls error. +func userType(rt reflect.Type) *userTypeInfo { + ut, err := validUserType(rt) + if err != nil { + error(err) + } + return ut +} +// A typeId represents a gob Type as an integer that can be passed on the wire. +// Internally, typeIds are used as keys to a map to recover the underlying type info. 
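validUserType above strips pointer indirections and guards against self-referential pointer types (type T *T) by walking a second value at half speed, the tortoise-and-hare cycle check its comment credits to Knuth. The same idea on a plain linked structure, as a standalone sketch with hypothetical types:

package main

import "fmt"

// node is a hypothetical singly linked structure standing in for the chain
// of pointer types that validUserType walks.
type node struct {
	name string
	next *node
}

// hasCycle advances one walker at full speed and one at half speed; if the
// two ever meet, the chain loops back on itself.
func hasCycle(n *node) bool {
	slow, fast := n, n
	for fast != nil && fast.next != nil {
		slow = slow.next
		fast = fast.next.next
		if slow == fast {
			return true
		}
	}
	return false
}

func main() {
	a := &node{name: "a"}
	b := &node{name: "b", next: a}
	a.next = b                              // a -> b -> a, like `type T *T` in the gob code
	fmt.Println(hasCycle(a))                // true
	fmt.Println(hasCycle(&node{name: "c"})) // false
}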
+type typeId int32 + +var nextId typeId // incremented for each new type we build +var typeLock sync.Mutex // set while building a type +const firstUserId = 64 // lowest id number granted to user + +type gobType interface { + id() typeId + setId(id typeId) + name() string + string() string // not public; only for debugging + safeString(seen map[typeId]bool) string +} + +var types = make(map[reflect.Type]gobType) +var idToType = make(map[typeId]gobType) +var builtinIdToType map[typeId]gobType // set in init() after builtins are established + +func setTypeId(typ gobType) { + nextId++ + typ.setId(nextId) + idToType[nextId] = typ +} + +func (t typeId) gobType() gobType { + if t == 0 { + return nil + } + return idToType[t] +} + +// string returns the string representation of the type associated with the typeId. +func (t typeId) string() string { + if t.gobType() == nil { + return "<nil>" + } + return t.gobType().string() +} + +// Name returns the name of the type associated with the typeId. +func (t typeId) name() string { + if t.gobType() == nil { + return "<nil>" + } + return t.gobType().name() +} + +// Common elements of all types. +type CommonType struct { + Name string + Id typeId +} + +func (t *CommonType) id() typeId { return t.Id } + +func (t *CommonType) setId(id typeId) { t.Id = id } + +func (t *CommonType) string() string { return t.Name } + +func (t *CommonType) safeString(seen map[typeId]bool) string { + return t.Name +} + +func (t *CommonType) name() string { return t.Name } + +// Create and check predefined types +// The string for tBytes is "bytes" not "[]byte" to signify its specialness. + +var ( + // Primordial types, needed during initialization. + // Always passed as pointers so the interface{} type + // goes through without losing its interfaceness. + tBool = bootstrapType("bool", (*bool)(nil), 1) + tInt = bootstrapType("int", (*int)(nil), 2) + tUint = bootstrapType("uint", (*uint)(nil), 3) + tFloat = bootstrapType("float", (*float64)(nil), 4) + tBytes = bootstrapType("bytes", (*[]byte)(nil), 5) + tString = bootstrapType("string", (*string)(nil), 6) + tComplex = bootstrapType("complex", (*complex128)(nil), 7) + tInterface = bootstrapType("interface", (*interface{})(nil), 8) + // Reserve some Ids for compatible expansion + tReserved7 = bootstrapType("_reserved1", (*struct{ r7 int })(nil), 9) + tReserved6 = bootstrapType("_reserved1", (*struct{ r6 int })(nil), 10) + tReserved5 = bootstrapType("_reserved1", (*struct{ r5 int })(nil), 11) + tReserved4 = bootstrapType("_reserved1", (*struct{ r4 int })(nil), 12) + tReserved3 = bootstrapType("_reserved1", (*struct{ r3 int })(nil), 13) + tReserved2 = bootstrapType("_reserved1", (*struct{ r2 int })(nil), 14) + tReserved1 = bootstrapType("_reserved1", (*struct{ r1 int })(nil), 15) +) + +// Predefined because it's needed by the Decoder +var tWireType = mustGetTypeInfo(reflect.Typeof(wireType{})).id +var wireTypeUserInfo *userTypeInfo // userTypeInfo of (*wireType) + +func init() { + // Some magic numbers to make sure there are no surprises. 
+ checkId(16, tWireType) + checkId(17, mustGetTypeInfo(reflect.Typeof(arrayType{})).id) + checkId(18, mustGetTypeInfo(reflect.Typeof(CommonType{})).id) + checkId(19, mustGetTypeInfo(reflect.Typeof(sliceType{})).id) + checkId(20, mustGetTypeInfo(reflect.Typeof(structType{})).id) + checkId(21, mustGetTypeInfo(reflect.Typeof(fieldType{})).id) + checkId(23, mustGetTypeInfo(reflect.Typeof(mapType{})).id) + + builtinIdToType = make(map[typeId]gobType) + for k, v := range idToType { + builtinIdToType[k] = v + } + + // Move the id space upwards to allow for growth in the predefined world + // without breaking existing files. + if nextId > firstUserId { + panic(fmt.Sprintln("nextId too large:", nextId)) + } + nextId = firstUserId + registerBasics() + wireTypeUserInfo = userType(reflect.Typeof((*wireType)(nil))) +} + +// Array type +type arrayType struct { + CommonType + Elem typeId + Len int +} + +func newArrayType(name string) *arrayType { + a := &arrayType{CommonType{Name: name}, 0, 0} + return a +} + +func (a *arrayType) init(elem gobType, len int) { + // Set our type id before evaluating the element's, in case it's our own. + setTypeId(a) + a.Elem = elem.id() + a.Len = len +} + +func (a *arrayType) safeString(seen map[typeId]bool) string { + if seen[a.Id] { + return a.Name + } + seen[a.Id] = true + return fmt.Sprintf("[%d]%s", a.Len, a.Elem.gobType().safeString(seen)) +} + +func (a *arrayType) string() string { return a.safeString(make(map[typeId]bool)) } + +// GobEncoder type (something that implements the GobEncoder interface) +type gobEncoderType struct { + CommonType +} + +func newGobEncoderType(name string) *gobEncoderType { + g := &gobEncoderType{CommonType{Name: name}} + setTypeId(g) + return g +} + +func (g *gobEncoderType) safeString(seen map[typeId]bool) string { + return g.Name +} + +func (g *gobEncoderType) string() string { return g.Name } + +// Map type +type mapType struct { + CommonType + Key typeId + Elem typeId +} + +func newMapType(name string) *mapType { + m := &mapType{CommonType{Name: name}, 0, 0} + return m +} + +func (m *mapType) init(key, elem gobType) { + // Set our type id before evaluating the element's, in case it's our own. + setTypeId(m) + m.Key = key.id() + m.Elem = elem.id() +} + +func (m *mapType) safeString(seen map[typeId]bool) string { + if seen[m.Id] { + return m.Name + } + seen[m.Id] = true + key := m.Key.gobType().safeString(seen) + elem := m.Elem.gobType().safeString(seen) + return fmt.Sprintf("map[%s]%s", key, elem) +} + +func (m *mapType) string() string { return m.safeString(make(map[typeId]bool)) } + +// Slice type +type sliceType struct { + CommonType + Elem typeId +} + +func newSliceType(name string) *sliceType { + s := &sliceType{CommonType{Name: name}, 0} + return s +} + +func (s *sliceType) init(elem gobType) { + // Set our type id before evaluating the element's, in case it's our own. 
+ setTypeId(s) + s.Elem = elem.id() +} + +func (s *sliceType) safeString(seen map[typeId]bool) string { + if seen[s.Id] { + return s.Name + } + seen[s.Id] = true + return fmt.Sprintf("[]%s", s.Elem.gobType().safeString(seen)) +} + +func (s *sliceType) string() string { return s.safeString(make(map[typeId]bool)) } + +// Struct type +type fieldType struct { + Name string + Id typeId +} + +type structType struct { + CommonType + Field []*fieldType +} + +func (s *structType) safeString(seen map[typeId]bool) string { + if s == nil { + return "<nil>" + } + if _, ok := seen[s.Id]; ok { + return s.Name + } + seen[s.Id] = true + str := s.Name + " = struct { " + for _, f := range s.Field { + str += fmt.Sprintf("%s %s; ", f.Name, f.Id.gobType().safeString(seen)) + } + str += "}" + return str +} + +func (s *structType) string() string { return s.safeString(make(map[typeId]bool)) } + +func newStructType(name string) *structType { + s := &structType{CommonType{Name: name}, nil} + // For historical reasons we set the id here rather than init. + // See the comment in newTypeObject for details. + setTypeId(s) + return s +} + +// newTypeObject allocates a gobType for the reflection type rt. +// Unless ut represents a GobEncoder, rt should be the base type +// of ut. +// This is only called from the encoding side. The decoding side +// works through typeIds and userTypeInfos alone. +func newTypeObject(name string, ut *userTypeInfo, rt reflect.Type) (gobType, os.Error) { + // Does this type implement GobEncoder? + if ut.isGobEncoder { + return newGobEncoderType(name), nil + } + var err os.Error + var type0, type1 gobType + defer func() { + if err != nil { + types[rt] = nil, false + } + }() + // Install the top-level type before the subtypes (e.g. struct before + // fields) so recursive types can be constructed safely. + switch t := rt.(type) { + // All basic types are easy: they are predefined. + case *reflect.BoolType: + return tBool.gobType(), nil + + case *reflect.IntType: + return tInt.gobType(), nil + + case *reflect.UintType: + return tUint.gobType(), nil + + case *reflect.FloatType: + return tFloat.gobType(), nil + + case *reflect.ComplexType: + return tComplex.gobType(), nil + + case *reflect.StringType: + return tString.gobType(), nil + + case *reflect.InterfaceType: + return tInterface.gobType(), nil + + case *reflect.ArrayType: + at := newArrayType(name) + types[rt] = at + type0, err = getBaseType("", t.Elem()) + if err != nil { + return nil, err + } + // Historical aside: + // For arrays, maps, and slices, we set the type id after the elements + // are constructed. This is to retain the order of type id allocation after + // a fix made to handle recursive types, which changed the order in + // which types are built. Delaying the setting in this way preserves + // type ids while allowing recursive types to be described. Structs, + // done below, were already handling recursion correctly so they + // assign the top-level id before those of the field. 
+ at.init(type0, t.Len()) + return at, nil + + case *reflect.MapType: + mt := newMapType(name) + types[rt] = mt + type0, err = getBaseType("", t.Key()) + if err != nil { + return nil, err + } + type1, err = getBaseType("", t.Elem()) + if err != nil { + return nil, err + } + mt.init(type0, type1) + return mt, nil + + case *reflect.SliceType: + // []byte == []uint8 is a special case + if t.Elem().Kind() == reflect.Uint8 { + return tBytes.gobType(), nil + } + st := newSliceType(name) + types[rt] = st + type0, err = getBaseType(t.Elem().Name(), t.Elem()) + if err != nil { + return nil, err + } + st.init(type0) + return st, nil + + case *reflect.StructType: + st := newStructType(name) + types[rt] = st + idToType[st.id()] = st + for i := 0; i < t.NumField(); i++ { + f := t.Field(i) + if !isExported(f.Name) { + continue + } + typ := userType(f.Type).base + tname := typ.Name() + if tname == "" { + t := userType(f.Type).base + tname = t.String() + } + gt, err := getBaseType(tname, f.Type) + if err != nil { + return nil, err + } + st.Field = append(st.Field, &fieldType{f.Name, gt.id()}) + } + return st, nil + + default: + return nil, os.ErrorString("gob NewTypeObject can't handle type: " + rt.String()) + } + return nil, nil +} + +// isExported reports whether this is an exported - upper case - name. +func isExported(name string) bool { + rune, _ := utf8.DecodeRuneInString(name) + return unicode.IsUpper(rune) +} + +// getBaseType returns the Gob type describing the given reflect.Type's base type. +// typeLock must be held. +func getBaseType(name string, rt reflect.Type) (gobType, os.Error) { + ut := userType(rt) + return getType(name, ut, ut.base) +} + +// getType returns the Gob type describing the given reflect.Type. +// Should be called only when handling GobEncoders/Decoders, +// which may be pointers. All other types are handled through the +// base type, never a pointer. +// typeLock must be held. +func getType(name string, ut *userTypeInfo, rt reflect.Type) (gobType, os.Error) { + typ, present := types[rt] + if present { + return typ, nil + } + typ, err := newTypeObject(name, ut, rt) + if err == nil { + types[rt] = typ + } + return typ, err +} + +func checkId(want, got typeId) { + if want != got { + fmt.Fprintf(os.Stderr, "checkId: %d should be %d\n", int(got), int(want)) + panic("bootstrap type wrong id: " + got.name() + " " + got.string() + " not " + want.string()) + } +} + +// used for building the basic types; called only from init(). the incoming +// interface always refers to a pointer. +func bootstrapType(name string, e interface{}, expect typeId) typeId { + rt := reflect.Typeof(e).(*reflect.PtrType).Elem() + _, present := types[rt] + if present { + panic("bootstrap type already present: " + name + ", " + rt.String()) + } + typ := &CommonType{Name: name} + types[rt] = typ + setTypeId(typ) + checkId(expect, nextId) + userType(rt) // might as well cache it now + return nextId +} + +// Representation of the information we send and receive about this type. +// Each value we send is preceded by its type definition: an encoded int. +// However, the very first time we send the value, we first send the pair +// (-id, wireType). +// For bootstrapping purposes, we assume that the recipient knows how +// to decode a wireType; it is exactly the wireType struct here, interpreted +// using the gob rules for sending a structure, except that we assume the +// ids for wireType and structType etc. are known. The relevant pieces +// are built in encode.go's init() function. 
+// To maintain binary compatibility, if you extend this type, always put +// the new fields last. +type wireType struct { + ArrayT *arrayType + SliceT *sliceType + StructT *structType + MapT *mapType + GobEncoderT *gobEncoderType +} + +func (w *wireType) string() string { + const unknown = "unknown type" + if w == nil { + return unknown + } + switch { + case w.ArrayT != nil: + return w.ArrayT.Name + case w.SliceT != nil: + return w.SliceT.Name + case w.StructT != nil: + return w.StructT.Name + case w.MapT != nil: + return w.MapT.Name + case w.GobEncoderT != nil: + return w.GobEncoderT.Name + } + return unknown +} + +type typeInfo struct { + id typeId + encoder *encEngine + wire *wireType +} + +var typeInfoMap = make(map[reflect.Type]*typeInfo) // protected by typeLock + +// typeLock must be held. +func getTypeInfo(ut *userTypeInfo) (*typeInfo, os.Error) { + rt := ut.base + if ut.isGobEncoder { + // We want the user type, not the base type. + rt = ut.user + } + info, ok := typeInfoMap[rt] + if ok { + return info, nil + } + info = new(typeInfo) + gt, err := getBaseType(rt.Name(), rt) + if err != nil { + return nil, err + } + info.id = gt.id() + + if ut.isGobEncoder { + userType, err := getType(rt.Name(), ut, rt) + if err != nil { + return nil, err + } + info.wire = &wireType{GobEncoderT: userType.id().gobType().(*gobEncoderType)} + typeInfoMap[ut.user] = info + return info, nil + } + + t := info.id.gobType() + switch typ := rt.(type) { + case *reflect.ArrayType: + info.wire = &wireType{ArrayT: t.(*arrayType)} + case *reflect.MapType: + info.wire = &wireType{MapT: t.(*mapType)} + case *reflect.SliceType: + // []byte == []uint8 is a special case handled separately + if typ.Elem().Kind() != reflect.Uint8 { + info.wire = &wireType{SliceT: t.(*sliceType)} + } + case *reflect.StructType: + info.wire = &wireType{StructT: t.(*structType)} + } + typeInfoMap[rt] = info + return info, nil +} + +// Called only when a panic is acceptable and unexpected. +func mustGetTypeInfo(rt reflect.Type) *typeInfo { + t, err := getTypeInfo(userType(rt)) + if err != nil { + panic("getTypeInfo: " + err.String()) + } + return t +} + +// GobEncoder is the interface describing data that provides its own +// representation for encoding values for transmission to a GobDecoder. +// A type that implements GobEncoder and GobDecoder has complete +// control over the representation of its data and may therefore +// contain things such as private fields, channels, and functions, +// which are not usually transmissable in gob streams. +// +// Note: Since gobs can be stored permanently, It is good design +// to guarantee the encoding used by a GobEncoder is stable as the +// software evolves. For instance, it might make sense for GobEncode +// to include a version number in the encoding. +type GobEncoder interface { + // GobEncode returns a byte slice representing the encoding of the + // receiver for transmission to a GobDecoder, usually of the same + // concrete type. + GobEncode() ([]byte, os.Error) +} + +// GobDecoder is the interface describing data that provides its own +// routine for decoding transmitted values sent by a GobEncoder. +type GobDecoder interface { + // GobDecode overwrites the receiver, which must be a pointer, + // with the value represented by the byte slice, which was written + // by GobEncode, usually for the same concrete type. 
+ GobDecode([]byte) os.Error +} + +var ( + nameToConcreteType = make(map[string]reflect.Type) + concreteTypeToName = make(map[reflect.Type]string) +) + +// RegisterName is like Register but uses the provided name rather than the +// type's default. +func RegisterName(name string, value interface{}) { + if name == "" { + // reserved for nil + panic("attempt to register empty name") + } + base := userType(reflect.Typeof(value)).base + // Check for incompatible duplicates. + if t, ok := nameToConcreteType[name]; ok && t != base { + panic("gob: registering duplicate types for " + name) + } + if n, ok := concreteTypeToName[base]; ok && n != name { + panic("gob: registering duplicate names for " + base.String()) + } + // Store the name and type provided by the user.... + nameToConcreteType[name] = reflect.Typeof(value) + // but the flattened type in the type table, since that's what decode needs. + concreteTypeToName[base] = name +} + +// Register records a type, identified by a value for that type, under its +// internal type name. That name will identify the concrete type of a value +// sent or received as an interface variable. Only types that will be +// transferred as implementations of interface values need to be registered. +// Expecting to be used only during initialization, it panics if the mapping +// between types and names is not a bijection. +func Register(value interface{}) { + // Default to printed representation for unnamed types + rt := reflect.Typeof(value) + name := rt.String() + + // But for named types (or pointers to them), qualify with import path. + // Dereference one pointer looking for a named type. + star := "" + if rt.Name() == "" { + if pt, ok := rt.(*reflect.PtrType); ok { + star = "*" + rt = pt + } + } + if rt.Name() != "" { + if rt.PkgPath() == "" { + name = star + rt.Name() + } else { + name = star + rt.PkgPath() + "." + rt.Name() + } + } + + RegisterName(name, value) +} + +func registerBasics() { + Register(int(0)) + Register(int8(0)) + Register(int16(0)) + Register(int32(0)) + Register(int64(0)) + Register(uint(0)) + Register(uint8(0)) + Register(uint16(0)) + Register(uint32(0)) + Register(uint64(0)) + Register(float32(0)) + Register(float64(0)) + Register(complex64(0i)) + Register(complex128(0i)) + Register(false) + Register("") + Register([]byte(nil)) +} diff --git a/src/cmd/gofix/testdata/reflect.type.go.out b/src/cmd/gofix/testdata/reflect.type.go.out new file mode 100644 index 000000000..8fd174841 --- /dev/null +++ b/src/cmd/gofix/testdata/reflect.type.go.out @@ -0,0 +1,789 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package gob + +import ( + "fmt" + "os" + "reflect" + "sync" + "unicode" + "utf8" +) + +// userTypeInfo stores the information associated with a type the user has handed +// to the package. It's computed once and stored in a map keyed by reflection +// type. +type userTypeInfo struct { + user reflect.Type // the type the user handed us + base reflect.Type // the base type after all indirections + indir int // number of indirections to reach the base type + isGobEncoder bool // does the type implement GobEncoder? + isGobDecoder bool // does the type implement GobDecoder? 
+ encIndir int8 // number of indirections to reach the receiver type; may be negative + decIndir int8 // number of indirections to reach the receiver type; may be negative +} + +var ( + // Protected by an RWMutex because we read it a lot and write + // it only when we see a new type, typically when compiling. + userTypeLock sync.RWMutex + userTypeCache = make(map[reflect.Type]*userTypeInfo) +) + +// validType returns, and saves, the information associated with user-provided type rt. +// If the user type is not valid, err will be non-nil. To be used when the error handler +// is not set up. +func validUserType(rt reflect.Type) (ut *userTypeInfo, err os.Error) { + userTypeLock.RLock() + ut = userTypeCache[rt] + userTypeLock.RUnlock() + if ut != nil { + return + } + // Now set the value under the write lock. + userTypeLock.Lock() + defer userTypeLock.Unlock() + if ut = userTypeCache[rt]; ut != nil { + // Lost the race; not a problem. + return + } + ut = new(userTypeInfo) + ut.base = rt + ut.user = rt + // A type that is just a cycle of pointers (such as type T *T) cannot + // be represented in gobs, which need some concrete data. We use a + // cycle detection algorithm from Knuth, Vol 2, Section 3.1, Ex 6, + // pp 539-540. As we step through indirections, run another type at + // half speed. If they meet up, there's a cycle. + slowpoke := ut.base // walks half as fast as ut.base + for { + pt := ut.base + if pt.Kind() != reflect.Ptr { + break + } + ut.base = pt.Elem() + if ut.base == slowpoke { // ut.base lapped slowpoke + // recursive pointer type. + return nil, os.ErrorString("can't represent recursive pointer type " + ut.base.String()) + } + if ut.indir%2 == 0 { + slowpoke = slowpoke.Elem() + } + ut.indir++ + } + ut.isGobEncoder, ut.encIndir = implementsInterface(ut.user, gobEncoderCheck) + ut.isGobDecoder, ut.decIndir = implementsInterface(ut.user, gobDecoderCheck) + userTypeCache[rt] = ut + return +} + +const ( + gobEncodeMethodName = "GobEncode" + gobDecodeMethodName = "GobDecode" +) + +// implements returns whether the type implements the interface, as encoded +// in the check function. +func implements(typ reflect.Type, check func(typ reflect.Type) bool) bool { + if typ.NumMethod() == 0 { // avoid allocations etc. unless there's some chance + return false + } + return check(typ) +} + +// gobEncoderCheck makes the type assertion a boolean function. +func gobEncoderCheck(typ reflect.Type) bool { + _, ok := reflect.Zero(typ).Interface().(GobEncoder) + return ok +} + +// gobDecoderCheck makes the type assertion a boolean function. +func gobDecoderCheck(typ reflect.Type) bool { + _, ok := reflect.Zero(typ).Interface().(GobDecoder) + return ok +} + +// implementsInterface reports whether the type implements the +// interface. (The actual check is done through the provided function.) +// It also returns the number of indirections required to get to the +// implementation. +func implementsInterface(typ reflect.Type, check func(typ reflect.Type) bool) (success bool, indir int8) { + if typ == nil { + return + } + rt := typ + // The type might be a pointer and we need to keep + // dereferencing to the base type until we find an implementation. + for { + if implements(rt, check) { + return true, indir + } + if p := rt; p.Kind() == reflect.Ptr { + indir++ + if indir > 100 { // insane number of indirections + return false, 0 + } + rt = p.Elem() + continue + } + break + } + // No luck yet, but if this is a base type (non-pointer), the pointer might satisfy. 
+ if typ.Kind() != reflect.Ptr { + // Not a pointer, but does the pointer work? + if implements(reflect.PtrTo(typ), check) { + return true, -1 + } + } + return false, 0 +} + +// userType returns, and saves, the information associated with user-provided type rt. +// If the user type is not valid, it calls error. +func userType(rt reflect.Type) *userTypeInfo { + ut, err := validUserType(rt) + if err != nil { + error(err) + } + return ut +} +// A typeId represents a gob Type as an integer that can be passed on the wire. +// Internally, typeIds are used as keys to a map to recover the underlying type info. +type typeId int32 + +var nextId typeId // incremented for each new type we build +var typeLock sync.Mutex // set while building a type +const firstUserId = 64 // lowest id number granted to user + +type gobType interface { + id() typeId + setId(id typeId) + name() string + string() string // not public; only for debugging + safeString(seen map[typeId]bool) string +} + +var types = make(map[reflect.Type]gobType) +var idToType = make(map[typeId]gobType) +var builtinIdToType map[typeId]gobType // set in init() after builtins are established + +func setTypeId(typ gobType) { + nextId++ + typ.setId(nextId) + idToType[nextId] = typ +} + +func (t typeId) gobType() gobType { + if t == 0 { + return nil + } + return idToType[t] +} + +// string returns the string representation of the type associated with the typeId. +func (t typeId) string() string { + if t.gobType() == nil { + return "<nil>" + } + return t.gobType().string() +} + +// Name returns the name of the type associated with the typeId. +func (t typeId) name() string { + if t.gobType() == nil { + return "<nil>" + } + return t.gobType().name() +} + +// Common elements of all types. +type CommonType struct { + Name string + Id typeId +} + +func (t *CommonType) id() typeId { return t.Id } + +func (t *CommonType) setId(id typeId) { t.Id = id } + +func (t *CommonType) string() string { return t.Name } + +func (t *CommonType) safeString(seen map[typeId]bool) string { + return t.Name +} + +func (t *CommonType) name() string { return t.Name } + +// Create and check predefined types +// The string for tBytes is "bytes" not "[]byte" to signify its specialness. + +var ( + // Primordial types, needed during initialization. + // Always passed as pointers so the interface{} type + // goes through without losing its interfaceness. 
+ tBool = bootstrapType("bool", (*bool)(nil), 1) + tInt = bootstrapType("int", (*int)(nil), 2) + tUint = bootstrapType("uint", (*uint)(nil), 3) + tFloat = bootstrapType("float", (*float64)(nil), 4) + tBytes = bootstrapType("bytes", (*[]byte)(nil), 5) + tString = bootstrapType("string", (*string)(nil), 6) + tComplex = bootstrapType("complex", (*complex128)(nil), 7) + tInterface = bootstrapType("interface", (*interface{})(nil), 8) + // Reserve some Ids for compatible expansion + tReserved7 = bootstrapType("_reserved1", (*struct{ r7 int })(nil), 9) + tReserved6 = bootstrapType("_reserved1", (*struct{ r6 int })(nil), 10) + tReserved5 = bootstrapType("_reserved1", (*struct{ r5 int })(nil), 11) + tReserved4 = bootstrapType("_reserved1", (*struct{ r4 int })(nil), 12) + tReserved3 = bootstrapType("_reserved1", (*struct{ r3 int })(nil), 13) + tReserved2 = bootstrapType("_reserved1", (*struct{ r2 int })(nil), 14) + tReserved1 = bootstrapType("_reserved1", (*struct{ r1 int })(nil), 15) +) + +// Predefined because it's needed by the Decoder +var tWireType = mustGetTypeInfo(reflect.Typeof(wireType{})).id +var wireTypeUserInfo *userTypeInfo // userTypeInfo of (*wireType) + +func init() { + // Some magic numbers to make sure there are no surprises. + checkId(16, tWireType) + checkId(17, mustGetTypeInfo(reflect.Typeof(arrayType{})).id) + checkId(18, mustGetTypeInfo(reflect.Typeof(CommonType{})).id) + checkId(19, mustGetTypeInfo(reflect.Typeof(sliceType{})).id) + checkId(20, mustGetTypeInfo(reflect.Typeof(structType{})).id) + checkId(21, mustGetTypeInfo(reflect.Typeof(fieldType{})).id) + checkId(23, mustGetTypeInfo(reflect.Typeof(mapType{})).id) + + builtinIdToType = make(map[typeId]gobType) + for k, v := range idToType { + builtinIdToType[k] = v + } + + // Move the id space upwards to allow for growth in the predefined world + // without breaking existing files. + if nextId > firstUserId { + panic(fmt.Sprintln("nextId too large:", nextId)) + } + nextId = firstUserId + registerBasics() + wireTypeUserInfo = userType(reflect.Typeof((*wireType)(nil))) +} + +// Array type +type arrayType struct { + CommonType + Elem typeId + Len int +} + +func newArrayType(name string) *arrayType { + a := &arrayType{CommonType{Name: name}, 0, 0} + return a +} + +func (a *arrayType) init(elem gobType, len int) { + // Set our type id before evaluating the element's, in case it's our own. + setTypeId(a) + a.Elem = elem.id() + a.Len = len +} + +func (a *arrayType) safeString(seen map[typeId]bool) string { + if seen[a.Id] { + return a.Name + } + seen[a.Id] = true + return fmt.Sprintf("[%d]%s", a.Len, a.Elem.gobType().safeString(seen)) +} + +func (a *arrayType) string() string { return a.safeString(make(map[typeId]bool)) } + +// GobEncoder type (something that implements the GobEncoder interface) +type gobEncoderType struct { + CommonType +} + +func newGobEncoderType(name string) *gobEncoderType { + g := &gobEncoderType{CommonType{Name: name}} + setTypeId(g) + return g +} + +func (g *gobEncoderType) safeString(seen map[typeId]bool) string { + return g.Name +} + +func (g *gobEncoderType) string() string { return g.Name } + +// Map type +type mapType struct { + CommonType + Key typeId + Elem typeId +} + +func newMapType(name string) *mapType { + m := &mapType{CommonType{Name: name}, 0, 0} + return m +} + +func (m *mapType) init(key, elem gobType) { + // Set our type id before evaluating the element's, in case it's our own. 
+ setTypeId(m) + m.Key = key.id() + m.Elem = elem.id() +} + +func (m *mapType) safeString(seen map[typeId]bool) string { + if seen[m.Id] { + return m.Name + } + seen[m.Id] = true + key := m.Key.gobType().safeString(seen) + elem := m.Elem.gobType().safeString(seen) + return fmt.Sprintf("map[%s]%s", key, elem) +} + +func (m *mapType) string() string { return m.safeString(make(map[typeId]bool)) } + +// Slice type +type sliceType struct { + CommonType + Elem typeId +} + +func newSliceType(name string) *sliceType { + s := &sliceType{CommonType{Name: name}, 0} + return s +} + +func (s *sliceType) init(elem gobType) { + // Set our type id before evaluating the element's, in case it's our own. + setTypeId(s) + s.Elem = elem.id() +} + +func (s *sliceType) safeString(seen map[typeId]bool) string { + if seen[s.Id] { + return s.Name + } + seen[s.Id] = true + return fmt.Sprintf("[]%s", s.Elem.gobType().safeString(seen)) +} + +func (s *sliceType) string() string { return s.safeString(make(map[typeId]bool)) } + +// Struct type +type fieldType struct { + Name string + Id typeId +} + +type structType struct { + CommonType + Field []*fieldType +} + +func (s *structType) safeString(seen map[typeId]bool) string { + if s == nil { + return "<nil>" + } + if _, ok := seen[s.Id]; ok { + return s.Name + } + seen[s.Id] = true + str := s.Name + " = struct { " + for _, f := range s.Field { + str += fmt.Sprintf("%s %s; ", f.Name, f.Id.gobType().safeString(seen)) + } + str += "}" + return str +} + +func (s *structType) string() string { return s.safeString(make(map[typeId]bool)) } + +func newStructType(name string) *structType { + s := &structType{CommonType{Name: name}, nil} + // For historical reasons we set the id here rather than init. + // See the comment in newTypeObject for details. + setTypeId(s) + return s +} + +// newTypeObject allocates a gobType for the reflection type rt. +// Unless ut represents a GobEncoder, rt should be the base type +// of ut. +// This is only called from the encoding side. The decoding side +// works through typeIds and userTypeInfos alone. +func newTypeObject(name string, ut *userTypeInfo, rt reflect.Type) (gobType, os.Error) { + // Does this type implement GobEncoder? + if ut.isGobEncoder { + return newGobEncoderType(name), nil + } + var err os.Error + var type0, type1 gobType + defer func() { + if err != nil { + types[rt] = nil, false + } + }() + // Install the top-level type before the subtypes (e.g. struct before + // fields) so recursive types can be constructed safely. + switch t := rt; t.Kind() { + // All basic types are easy: they are predefined. + case reflect.Bool: + return tBool.gobType(), nil + + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return tInt.gobType(), nil + + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return tUint.gobType(), nil + + case reflect.Float32, reflect.Float64: + return tFloat.gobType(), nil + + case reflect.Complex64, reflect.Complex128: + return tComplex.gobType(), nil + + case reflect.String: + return tString.gobType(), nil + + case reflect.Interface: + return tInterface.gobType(), nil + + case reflect.Array: + at := newArrayType(name) + types[rt] = at + type0, err = getBaseType("", t.Elem()) + if err != nil { + return nil, err + } + // Historical aside: + // For arrays, maps, and slices, we set the type id after the elements + // are constructed. 
This is to retain the order of type id allocation after + // a fix made to handle recursive types, which changed the order in + // which types are built. Delaying the setting in this way preserves + // type ids while allowing recursive types to be described. Structs, + // done below, were already handling recursion correctly so they + // assign the top-level id before those of the field. + at.init(type0, t.Len()) + return at, nil + + case reflect.Map: + mt := newMapType(name) + types[rt] = mt + type0, err = getBaseType("", t.Key()) + if err != nil { + return nil, err + } + type1, err = getBaseType("", t.Elem()) + if err != nil { + return nil, err + } + mt.init(type0, type1) + return mt, nil + + case reflect.Slice: + // []byte == []uint8 is a special case + if t.Elem().Kind() == reflect.Uint8 { + return tBytes.gobType(), nil + } + st := newSliceType(name) + types[rt] = st + type0, err = getBaseType(t.Elem().Name(), t.Elem()) + if err != nil { + return nil, err + } + st.init(type0) + return st, nil + + case reflect.Struct: + st := newStructType(name) + types[rt] = st + idToType[st.id()] = st + for i := 0; i < t.NumField(); i++ { + f := t.Field(i) + if !isExported(f.Name) { + continue + } + typ := userType(f.Type).base + tname := typ.Name() + if tname == "" { + t := userType(f.Type).base + tname = t.String() + } + gt, err := getBaseType(tname, f.Type) + if err != nil { + return nil, err + } + st.Field = append(st.Field, &fieldType{f.Name, gt.id()}) + } + return st, nil + + default: + return nil, os.ErrorString("gob NewTypeObject can't handle type: " + rt.String()) + } + return nil, nil +} + +// isExported reports whether this is an exported - upper case - name. +func isExported(name string) bool { + rune, _ := utf8.DecodeRuneInString(name) + return unicode.IsUpper(rune) +} + +// getBaseType returns the Gob type describing the given reflect.Type's base type. +// typeLock must be held. +func getBaseType(name string, rt reflect.Type) (gobType, os.Error) { + ut := userType(rt) + return getType(name, ut, ut.base) +} + +// getType returns the Gob type describing the given reflect.Type. +// Should be called only when handling GobEncoders/Decoders, +// which may be pointers. All other types are handled through the +// base type, never a pointer. +// typeLock must be held. +func getType(name string, ut *userTypeInfo, rt reflect.Type) (gobType, os.Error) { + typ, present := types[rt] + if present { + return typ, nil + } + typ, err := newTypeObject(name, ut, rt) + if err == nil { + types[rt] = typ + } + return typ, err +} + +func checkId(want, got typeId) { + if want != got { + fmt.Fprintf(os.Stderr, "checkId: %d should be %d\n", int(got), int(want)) + panic("bootstrap type wrong id: " + got.name() + " " + got.string() + " not " + want.string()) + } +} + +// used for building the basic types; called only from init(). the incoming +// interface always refers to a pointer. +func bootstrapType(name string, e interface{}, expect typeId) typeId { + rt := reflect.Typeof(e).Elem() + _, present := types[rt] + if present { + panic("bootstrap type already present: " + name + ", " + rt.String()) + } + typ := &CommonType{Name: name} + types[rt] = typ + setTypeId(typ) + checkId(expect, nextId) + userType(rt) // might as well cache it now + return nextId +} + +// Representation of the information we send and receive about this type. +// Each value we send is preceded by its type definition: an encoded int. +// However, the very first time we send the value, we first send the pair +// (-id, wireType). 
+// For bootstrapping purposes, we assume that the recipient knows how +// to decode a wireType; it is exactly the wireType struct here, interpreted +// using the gob rules for sending a structure, except that we assume the +// ids for wireType and structType etc. are known. The relevant pieces +// are built in encode.go's init() function. +// To maintain binary compatibility, if you extend this type, always put +// the new fields last. +type wireType struct { + ArrayT *arrayType + SliceT *sliceType + StructT *structType + MapT *mapType + GobEncoderT *gobEncoderType +} + +func (w *wireType) string() string { + const unknown = "unknown type" + if w == nil { + return unknown + } + switch { + case w.ArrayT != nil: + return w.ArrayT.Name + case w.SliceT != nil: + return w.SliceT.Name + case w.StructT != nil: + return w.StructT.Name + case w.MapT != nil: + return w.MapT.Name + case w.GobEncoderT != nil: + return w.GobEncoderT.Name + } + return unknown +} + +type typeInfo struct { + id typeId + encoder *encEngine + wire *wireType +} + +var typeInfoMap = make(map[reflect.Type]*typeInfo) // protected by typeLock + +// typeLock must be held. +func getTypeInfo(ut *userTypeInfo) (*typeInfo, os.Error) { + rt := ut.base + if ut.isGobEncoder { + // We want the user type, not the base type. + rt = ut.user + } + info, ok := typeInfoMap[rt] + if ok { + return info, nil + } + info = new(typeInfo) + gt, err := getBaseType(rt.Name(), rt) + if err != nil { + return nil, err + } + info.id = gt.id() + + if ut.isGobEncoder { + userType, err := getType(rt.Name(), ut, rt) + if err != nil { + return nil, err + } + info.wire = &wireType{GobEncoderT: userType.id().gobType().(*gobEncoderType)} + typeInfoMap[ut.user] = info + return info, nil + } + + t := info.id.gobType() + switch typ := rt; typ.Kind() { + case reflect.Array: + info.wire = &wireType{ArrayT: t.(*arrayType)} + case reflect.Map: + info.wire = &wireType{MapT: t.(*mapType)} + case reflect.Slice: + // []byte == []uint8 is a special case handled separately + if typ.Elem().Kind() != reflect.Uint8 { + info.wire = &wireType{SliceT: t.(*sliceType)} + } + case reflect.Struct: + info.wire = &wireType{StructT: t.(*structType)} + } + typeInfoMap[rt] = info + return info, nil +} + +// Called only when a panic is acceptable and unexpected. +func mustGetTypeInfo(rt reflect.Type) *typeInfo { + t, err := getTypeInfo(userType(rt)) + if err != nil { + panic("getTypeInfo: " + err.String()) + } + return t +} + +// GobEncoder is the interface describing data that provides its own +// representation for encoding values for transmission to a GobDecoder. +// A type that implements GobEncoder and GobDecoder has complete +// control over the representation of its data and may therefore +// contain things such as private fields, channels, and functions, +// which are not usually transmissable in gob streams. +// +// Note: Since gobs can be stored permanently, It is good design +// to guarantee the encoding used by a GobEncoder is stable as the +// software evolves. For instance, it might make sense for GobEncode +// to include a version number in the encoding. +type GobEncoder interface { + // GobEncode returns a byte slice representing the encoding of the + // receiver for transmission to a GobDecoder, usually of the same + // concrete type. + GobEncode() ([]byte, os.Error) +} + +// GobDecoder is the interface describing data that provides its own +// routine for decoding transmitted values sent by a GobEncoder. 
+type GobDecoder interface { + // GobDecode overwrites the receiver, which must be a pointer, + // with the value represented by the byte slice, which was written + // by GobEncode, usually for the same concrete type. + GobDecode([]byte) os.Error +} + +var ( + nameToConcreteType = make(map[string]reflect.Type) + concreteTypeToName = make(map[reflect.Type]string) +) + +// RegisterName is like Register but uses the provided name rather than the +// type's default. +func RegisterName(name string, value interface{}) { + if name == "" { + // reserved for nil + panic("attempt to register empty name") + } + base := userType(reflect.Typeof(value)).base + // Check for incompatible duplicates. + if t, ok := nameToConcreteType[name]; ok && t != base { + panic("gob: registering duplicate types for " + name) + } + if n, ok := concreteTypeToName[base]; ok && n != name { + panic("gob: registering duplicate names for " + base.String()) + } + // Store the name and type provided by the user.... + nameToConcreteType[name] = reflect.Typeof(value) + // but the flattened type in the type table, since that's what decode needs. + concreteTypeToName[base] = name +} + +// Register records a type, identified by a value for that type, under its +// internal type name. That name will identify the concrete type of a value +// sent or received as an interface variable. Only types that will be +// transferred as implementations of interface values need to be registered. +// Expecting to be used only during initialization, it panics if the mapping +// between types and names is not a bijection. +func Register(value interface{}) { + // Default to printed representation for unnamed types + rt := reflect.Typeof(value) + name := rt.String() + + // But for named types (or pointers to them), qualify with import path. + // Dereference one pointer looking for a named type. + star := "" + if rt.Name() == "" { + if pt := rt; pt.Kind() == reflect.Ptr { + star = "*" + rt = pt + } + } + if rt.Name() != "" { + if rt.PkgPath() == "" { + name = star + rt.Name() + } else { + name = star + rt.PkgPath() + "." + rt.Name() + } + } + + RegisterName(name, value) +} + +func registerBasics() { + Register(int(0)) + Register(int8(0)) + Register(int16(0)) + Register(int32(0)) + Register(int64(0)) + Register(uint(0)) + Register(uint8(0)) + Register(uint16(0)) + Register(uint32(0)) + Register(uint64(0)) + Register(float32(0)) + Register(float64(0)) + Register(complex64(0i)) + Register(complex128(0i)) + Register(false) + Register("") + Register([]byte(nil)) +} diff --git a/src/cmd/gofix/typecheck.go b/src/cmd/gofix/typecheck.go new file mode 100644 index 000000000..d565e7b4b --- /dev/null +++ b/src/cmd/gofix/typecheck.go @@ -0,0 +1,579 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "fmt" + "go/ast" + "go/token" + "os" + "reflect" + "strings" +) + +// Partial type checker. +// +// The fact that it is partial is very important: the input is +// an AST and a description of some type information to +// assume about one or more packages, but not all the +// packages that the program imports. The checker is +// expected to do as much as it can with what it has been +// given. 
There is not enough information supplied to do +// a full type check, but the type checker is expected to +// apply information that can be derived from variable +// declarations, function and method returns, and type switches +// as far as it can, so that the caller can still tell the types +// of expression relevant to a particular fix. +// +// TODO(rsc,gri): Replace with go/typechecker. +// Doing that could be an interesting test case for go/typechecker: +// the constraints about working with partial information will +// likely exercise it in interesting ways. The ideal interface would +// be to pass typecheck a map from importpath to package API text +// (Go source code), but for now we use data structures (TypeConfig, Type). +// +// The strings mostly use gofmt form. +// +// A Field or FieldList has as its type a comma-separated list +// of the types of the fields. For example, the field list +// x, y, z int +// has type "int, int, int". + +// The prefix "type " is the type of a type. +// For example, given +// var x int +// type T int +// x's type is "int" but T's type is "type int". +// mkType inserts the "type " prefix. +// getType removes it. +// isType tests for it. + +func mkType(t string) string { + return "type " + t +} + +func getType(t string) string { + if !isType(t) { + return "" + } + return t[len("type "):] +} + +func isType(t string) bool { + return strings.HasPrefix(t, "type ") +} + +// TypeConfig describes the universe of relevant types. +// For ease of creation, the types are all referred to by string +// name (e.g., "reflect.Value"). TypeByName is the only place +// where the strings are resolved. + +type TypeConfig struct { + Type map[string]*Type + Var map[string]string + Func map[string]string +} + +// typeof returns the type of the given name, which may be of +// the form "x" or "p.X". +func (cfg *TypeConfig) typeof(name string) string { + if cfg.Var != nil { + if t := cfg.Var[name]; t != "" { + return t + } + } + if cfg.Func != nil { + if t := cfg.Func[name]; t != "" { + return "func()" + t + } + } + return "" +} + +// Type describes the Fields and Methods of a type. +// If the field or method cannot be found there, it is next +// looked for in the Embed list. +type Type struct { + Field map[string]string // map field name to type + Method map[string]string // map method name to comma-separated return types + Embed []string // list of types this type embeds (for extra methods) +} + +// dot returns the type of "typ.name", making its decision +// using the type information in cfg. +func (typ *Type) dot(cfg *TypeConfig, name string) string { + if typ.Field != nil { + if t := typ.Field[name]; t != "" { + return t + } + } + if typ.Method != nil { + if t := typ.Method[name]; t != "" { + return t + } + } + + for _, e := range typ.Embed { + etyp := cfg.Type[e] + if etyp != nil { + if t := etyp.dot(cfg, name); t != "" { + return t + } + } + } + + return "" +} + +// typecheck type checks the AST f assuming the information in cfg. +// It returns a map from AST nodes to type information in gofmt string form. +func typecheck(cfg *TypeConfig, f *ast.File) map[interface{}]string { + typeof := make(map[interface{}]string) + + // gather function declarations + for _, decl := range f.Decls { + fn, ok := decl.(*ast.FuncDecl) + if !ok { + continue + } + typecheck1(cfg, fn.Type, typeof) + t := typeof[fn.Type] + if fn.Recv != nil { + // The receiver must be a type. 
+ rcvr := typeof[fn.Recv] + if !isType(rcvr) { + if len(fn.Recv.List) != 1 { + continue + } + rcvr = mkType(gofmt(fn.Recv.List[0].Type)) + typeof[fn.Recv.List[0].Type] = rcvr + } + rcvr = getType(rcvr) + if rcvr != "" && rcvr[0] == '*' { + rcvr = rcvr[1:] + } + typeof[rcvr+"."+fn.Name.Name] = t + } else { + if isType(t) { + t = getType(t) + } else { + t = gofmt(fn.Type) + } + typeof[fn.Name] = t + + // Record typeof[fn.Name.Obj] for future references to fn.Name. + typeof[fn.Name.Obj] = t + } + } + + typecheck1(cfg, f, typeof) + return typeof +} + +func makeExprList(a []*ast.Ident) []ast.Expr { + var b []ast.Expr + for _, x := range a { + b = append(b, x) + } + return b +} + +// Typecheck1 is the recursive form of typecheck. +// It is like typecheck but adds to the information in typeof +// instead of allocating a new map. +func typecheck1(cfg *TypeConfig, f interface{}, typeof map[interface{}]string) { + // set sets the type of n to typ. + // If isDecl is true, n is being declared. + set := func(n ast.Expr, typ string, isDecl bool) { + if typeof[n] != "" || typ == "" { + return + } + typeof[n] = typ + + // If we obtained typ from the declaration of x + // propagate the type to all the uses. + // The !isDecl case is a cheat here, but it makes + // up in some cases for not paying attention to + // struct fields. The real type checker will be + // more accurate so we won't need the cheat. + if id, ok := n.(*ast.Ident); ok && id.Obj != nil && (isDecl || typeof[id.Obj] == "") { + typeof[id.Obj] = typ + } + } + + // Type-check an assignment lhs = rhs. + // If isDecl is true, this is := so we can update + // the types of the objects that lhs refers to. + typecheckAssign := func(lhs, rhs []ast.Expr, isDecl bool) { + if len(lhs) > 1 && len(rhs) == 1 { + if _, ok := rhs[0].(*ast.CallExpr); ok { + t := split(typeof[rhs[0]]) + // Lists should have same length but may not; pair what can be paired. + for i := 0; i < len(lhs) && i < len(t); i++ { + set(lhs[i], t[i], isDecl) + } + return + } + } + if len(lhs) == 1 && len(rhs) == 2 { + // x = y, ok + rhs = rhs[:1] + } else if len(lhs) == 2 && len(rhs) == 1 { + // x, ok = y + lhs = lhs[:1] + } + + // Match as much as we can. + for i := 0; i < len(lhs) && i < len(rhs); i++ { + x, y := lhs[i], rhs[i] + if typeof[y] != "" { + set(x, typeof[y], isDecl) + } else { + set(y, typeof[x], false) + } + } + } + + // The main type check is a recursive algorithm implemented + // by walkBeforeAfter(n, before, after). + // Most of it is bottom-up, but in a few places we need + // to know the type of the function we are checking. + // The before function records that information on + // the curfn stack. + var curfn []*ast.FuncType + + before := func(n interface{}) { + // push function type on stack + switch n := n.(type) { + case *ast.FuncDecl: + curfn = append(curfn, n.Type) + case *ast.FuncLit: + curfn = append(curfn, n.Type) + } + } + + // After is the real type checker. 
+ after := func(n interface{}) { + if n == nil { + return + } + if false && reflect.Typeof(n).Kind() == reflect.Ptr { // debugging trace + defer func() { + if t := typeof[n]; t != "" { + pos := fset.Position(n.(ast.Node).Pos()) + fmt.Fprintf(os.Stderr, "%s: typeof[%s] = %s\n", pos.String(), gofmt(n), t) + } + }() + } + + switch n := n.(type) { + case *ast.FuncDecl, *ast.FuncLit: + // pop function type off stack + curfn = curfn[:len(curfn)-1] + + case *ast.FuncType: + typeof[n] = mkType(joinFunc(split(typeof[n.Params]), split(typeof[n.Results]))) + + case *ast.FieldList: + // Field list is concatenation of sub-lists. + t := "" + for _, field := range n.List { + if t != "" { + t += ", " + } + t += typeof[field] + } + typeof[n] = t + + case *ast.Field: + // Field is one instance of the type per name. + all := "" + t := typeof[n.Type] + if !isType(t) { + // Create a type, because it is typically *T or *p.T + // and we might care about that type. + t = mkType(gofmt(n.Type)) + typeof[n.Type] = t + } + t = getType(t) + if len(n.Names) == 0 { + all = t + } else { + for _, id := range n.Names { + if all != "" { + all += ", " + } + all += t + typeof[id.Obj] = t + typeof[id] = t + } + } + typeof[n] = all + + case *ast.ValueSpec: + // var declaration. Use type if present. + if n.Type != nil { + t := typeof[n.Type] + if !isType(t) { + t = mkType(gofmt(n.Type)) + typeof[n.Type] = t + } + t = getType(t) + for _, id := range n.Names { + set(id, t, true) + } + } + // Now treat same as assignment. + typecheckAssign(makeExprList(n.Names), n.Values, true) + + case *ast.AssignStmt: + typecheckAssign(n.Lhs, n.Rhs, n.Tok == token.DEFINE) + + case *ast.Ident: + // Identifier can take its type from underlying object. + if t := typeof[n.Obj]; t != "" { + typeof[n] = t + } + + case *ast.SelectorExpr: + // Field or method. + name := n.Sel.Name + if t := typeof[n.X]; t != "" { + if strings.HasPrefix(t, "*") { + t = t[1:] // implicit * + } + if typ := cfg.Type[t]; typ != nil { + if t := typ.dot(cfg, name); t != "" { + typeof[n] = t + return + } + } + tt := typeof[t+"."+name] + if isType(tt) { + typeof[n] = getType(tt) + return + } + } + // Package selector. + if x, ok := n.X.(*ast.Ident); ok && x.Obj == nil { + str := x.Name + "." + name + if cfg.Type[str] != nil { + typeof[n] = mkType(str) + return + } + if t := cfg.typeof(x.Name + "." + name); t != "" { + typeof[n] = t + return + } + } + + case *ast.CallExpr: + // make(T) has type T. + if isTopName(n.Fun, "make") && len(n.Args) >= 1 { + typeof[n] = gofmt(n.Args[0]) + return + } + // Otherwise, use type of function to determine arguments. + t := typeof[n.Fun] + in, out := splitFunc(t) + if in == nil && out == nil { + return + } + typeof[n] = join(out) + for i, arg := range n.Args { + if i >= len(in) { + break + } + if typeof[arg] == "" { + typeof[arg] = in[i] + } + } + + case *ast.TypeAssertExpr: + // x.(type) has type of x. + if n.Type == nil { + typeof[n] = typeof[n.X] + return + } + // x.(T) has type T. + if t := typeof[n.Type]; isType(t) { + typeof[n] = getType(t) + } + + case *ast.SliceExpr: + // x[i:j] has type of x. + typeof[n] = typeof[n.X] + + case *ast.IndexExpr: + // x[i] has key type of x's type. + t := typeof[n.X] + if strings.HasPrefix(t, "[") || strings.HasPrefix(t, "map[") { + // Lazy: assume there are no nested [] in the array + // length or map key type. + if i := strings.Index(t, "]"); i >= 0 { + typeof[n] = t[i+1:] + } + } + + case *ast.StarExpr: + // *x for x of type *T has type T when x is an expr. 
+ // We don't use the result when *x is a type, but + // compute it anyway. + t := typeof[n.X] + if isType(t) { + typeof[n] = "type *" + getType(t) + } else if strings.HasPrefix(t, "*") { + typeof[n] = t[len("*"):] + } + + case *ast.UnaryExpr: + // &x for x of type T has type *T. + t := typeof[n.X] + if t != "" && n.Op == token.AND { + typeof[n] = "&" + t + } + + case *ast.CompositeLit: + // T{...} has type T. + typeof[n] = gofmt(n.Type) + + case *ast.ParenExpr: + // (x) has type of x. + typeof[n] = typeof[n.X] + + case *ast.TypeSwitchStmt: + // Type of variable changes for each case in type switch, + // but go/parser generates just one variable. + // Repeat type check for each case with more precise + // type information. + as, ok := n.Assign.(*ast.AssignStmt) + if !ok { + return + } + varx, ok := as.Lhs[0].(*ast.Ident) + if !ok { + return + } + t := typeof[varx] + for _, cas := range n.Body.List { + cas := cas.(*ast.CaseClause) + if len(cas.List) == 1 { + // Variable has specific type only when there is + // exactly one type in the case list. + if tt := typeof[cas.List[0]]; isType(tt) { + tt = getType(tt) + typeof[varx] = tt + typeof[varx.Obj] = tt + typecheck1(cfg, cas.Body, typeof) + } + } + } + // Restore t. + typeof[varx] = t + typeof[varx.Obj] = t + + case *ast.ReturnStmt: + if len(curfn) == 0 { + // Probably can't happen. + return + } + f := curfn[len(curfn)-1] + res := n.Results + if f.Results != nil { + t := split(typeof[f.Results]) + for i := 0; i < len(res) && i < len(t); i++ { + set(res[i], t[i], false) + } + } + } + } + walkBeforeAfter(f, before, after) +} + +// Convert between function type strings and lists of types. +// Using strings makes this a little harder, but it makes +// a lot of the rest of the code easier. This will all go away +// when we can use go/typechecker directly. + +// splitFunc splits "func(x,y,z) (a,b,c)" into ["x", "y", "z"] and ["a", "b", "c"]. +func splitFunc(s string) (in, out []string) { + if !strings.HasPrefix(s, "func(") { + return nil, nil + } + + i := len("func(") // index of beginning of 'in' arguments + nparen := 0 + for j := i; j < len(s); j++ { + switch s[j] { + case '(': + nparen++ + case ')': + nparen-- + if nparen < 0 { + // found end of parameter list + out := strings.TrimSpace(s[j+1:]) + if len(out) >= 2 && out[0] == '(' && out[len(out)-1] == ')' { + out = out[1 : len(out)-1] + } + return split(s[i:j]), split(out) + } + } + } + return nil, nil +} + +// joinFunc is the inverse of splitFunc. +func joinFunc(in, out []string) string { + outs := "" + if len(out) == 1 { + outs = " " + out[0] + } else if len(out) > 1 { + outs = " (" + join(out) + ")" + } + return "func(" + join(in) + ")" + outs +} + +// split splits "int, float" into ["int", "float"] and splits "" into []. +func split(s string) []string { + out := []string{} + i := 0 // current type being scanned is s[i:j]. + nparen := 0 + for j := 0; j < len(s); j++ { + switch s[j] { + case ' ': + if i == j { + i++ + } + case '(': + nparen++ + case ')': + nparen-- + if nparen < 0 { + // probably can't happen + return nil + } + case ',': + if nparen == 0 { + if i < j { + out = append(out, s[i:j]) + } + i = j + 1 + } + } + } + if nparen != 0 { + // probably can't happen + return nil + } + if i < len(s) { + out = append(out, s[i:]) + } + return out +} + +// join is the inverse of split. 
+func join(x []string) string { + return strings.Join(x, ", ") +} diff --git a/src/cmd/gofmt/Makefile b/src/cmd/gofmt/Makefile index 5f2f454e8..dc5b060e6 100644 --- a/src/cmd/gofmt/Makefile +++ b/src/cmd/gofmt/Makefile @@ -15,5 +15,5 @@ include ../../Make.cmd test: $(TARG) ./test.sh -smoketest: $(TARG) - (cd testdata; ./test.sh) +testshort: + gotest -test.short diff --git a/src/cmd/gofmt/doc.go b/src/cmd/gofmt/doc.go index 2d2c9ae61..e44030eee 100644 --- a/src/cmd/gofmt/doc.go +++ b/src/cmd/gofmt/doc.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. /* - Gofmt formats Go programs. Without an explicit path, it processes the standard input. Given a file, @@ -16,14 +15,16 @@ Usage: The flags are: -l - just list files whose formatting differs from gofmt's; generate no other output - unless -w is also set. + just list files whose formatting differs from gofmt's; + generate no other output unless -w is also set. -r rule apply the rewrite rule to the source before reformatting. -s try to simplify code (after applying the rewrite rule, if any). -w if set, overwrite each input file with its output. + -comments=true + print comments; if false, all comments are elided from the output. -spaces align with spaces instead of tabs. -tabindent @@ -31,15 +32,6 @@ The flags are: -tabwidth=8 tab width in spaces. -Debugging flags: - - -trace - print parse trace. - -ast - print AST (before rewrites). - -comments=true - print comments; if false, all comments are elided from the output. - The rewrite rule specified with the -r flag must be a string of the form: pattern -> replacement diff --git a/src/cmd/gofmt/gofmt.go b/src/cmd/gofmt/gofmt.go index 224aee717..ce274aa21 100644 --- a/src/cmd/gofmt/gofmt.go +++ b/src/cmd/gofmt/gofmt.go @@ -13,9 +13,11 @@ import ( "go/printer" "go/scanner" "go/token" + "io" "io/ioutil" "os" "path/filepath" + "runtime/pprof" "strings" ) @@ -27,15 +29,14 @@ var ( rewriteRule = flag.String("r", "", "rewrite rule (e.g., 'α[β:len(α)] -> α[β:]')") simplifyAST = flag.Bool("s", false, "simplify code") - // debugging support - comments = flag.Bool("comments", true, "print comments") - trace = flag.Bool("trace", false, "print parse trace") - printAST = flag.Bool("ast", false, "print AST (before rewrites)") - // layout control + comments = flag.Bool("comments", true, "print comments") tabWidth = flag.Int("tabwidth", 8, "tab width") tabIndent = flag.Bool("tabindent", true, "indent with tabs independent of -spaces") useSpaces = flag.Bool("spaces", true, "align with spaces instead of tabs") + + // debugging + cpuprofile = flag.String("cpuprofile", "", "write cpu profile to this file") ) @@ -66,9 +67,6 @@ func initParserMode() { if *comments { parserMode |= parser.ParseComments } - if *trace { - parserMode |= parser.Trace - } } @@ -89,20 +87,25 @@ func isGoFile(f *os.FileInfo) bool { } -func processFile(f *os.File) os.Error { - src, err := ioutil.ReadAll(f) - if err != nil { - return err +// If in == nil, the source is the contents of the file with the given filename. 
+func processFile(filename string, in io.Reader, out io.Writer) os.Error { + if in == nil { + f, err := os.Open(filename) + if err != nil { + return err + } + defer f.Close() + in = f } - file, err := parser.ParseFile(fset, f.Name(), src, parserMode) - + src, err := ioutil.ReadAll(in) if err != nil { return err } - if *printAST { - ast.Print(file) + file, err := parser.ParseFile(fset, filename, src, parserMode) + if err != nil { + return err } if rewrite != nil { @@ -123,10 +126,10 @@ func processFile(f *os.File) os.Error { if !bytes.Equal(src, res) { // formatting has changed if *list { - fmt.Fprintln(os.Stdout, f.Name()) + fmt.Fprintln(out, filename) } if *write { - err = ioutil.WriteFile(f.Name(), res, 0) + err = ioutil.WriteFile(filename, res, 0) if err != nil { return err } @@ -134,23 +137,13 @@ func processFile(f *os.File) os.Error { } if !*list && !*write { - _, err = os.Stdout.Write(res) + _, err = out.Write(res) } return err } -func processFileByName(filename string) os.Error { - file, err := os.Open(filename, os.O_RDONLY, 0) - if err != nil { - return err - } - defer file.Close() - return processFile(file) -} - - type fileVisitor chan os.Error func (v fileVisitor) VisitDir(path string, f *os.FileInfo) bool { @@ -161,7 +154,7 @@ func (v fileVisitor) VisitDir(path string, f *os.FileInfo) bool { func (v fileVisitor) VisitFile(path string, f *os.FileInfo) { if isGoFile(f) { v <- nil // synchronize error handler - if err := processFileByName(path); err != nil { + if err := processFile(path, nil, os.Stdout); err != nil { v <- err } } @@ -169,30 +162,47 @@ func (v fileVisitor) VisitFile(path string, f *os.FileInfo) { func walkDir(path string) { - // start an error handler - done := make(chan bool) v := make(fileVisitor) go func() { - for err := range v { - if err != nil { - report(err) - } - } - done <- true + filepath.Walk(path, v, v) + close(v) }() - // walk the tree - filepath.Walk(path, v, v) - close(v) // terminate error handler loop - <-done // wait for all errors to be reported + for err := range v { + if err != nil { + report(err) + } + } } func main() { + // call gofmtMain in a separate function + // so that it can use defer and have them + // run before the exit. + gofmtMain() + os.Exit(exitCode) +} + + +func gofmtMain() { flag.Usage = usage flag.Parse() if *tabWidth < 0 { fmt.Fprintf(os.Stderr, "negative tabwidth %d\n", *tabWidth) - os.Exit(2) + exitCode = 2 + return + } + + if *cpuprofile != "" { + f, err := os.Create(*cpuprofile) + if err != nil { + fmt.Fprintf(os.Stderr, "creating cpu profile: %s\n", err) + exitCode = 2 + return + } + defer f.Close() + pprof.StartCPUProfile(f) + defer pprof.StopCPUProfile() } initParserMode() @@ -200,9 +210,10 @@ func main() { initRewrite() if flag.NArg() == 0 { - if err := processFile(os.Stdin); err != nil { + if err := processFile("<standard input>", os.Stdin, os.Stdout); err != nil { report(err) } + return } for i := 0; i < flag.NArg(); i++ { @@ -211,13 +222,11 @@ func main() { case err != nil: report(err) case dir.IsRegular(): - if err := processFileByName(path); err != nil { + if err := processFile(path, nil, os.Stdout); err != nil { report(err) } case dir.IsDirectory(): walkDir(path) } } - - os.Exit(exitCode) } diff --git a/src/cmd/gofmt/gofmt_test.go b/src/cmd/gofmt/gofmt_test.go new file mode 100644 index 000000000..4ec94e293 --- /dev/null +++ b/src/cmd/gofmt/gofmt_test.go @@ -0,0 +1,81 @@ +// Copyright 2011 The Go Authors. All rights reserved. 
+// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "bytes" + "io/ioutil" + "path/filepath" + "strings" + "testing" +) + + +func runTest(t *testing.T, dirname, in, out, flags string) { + in = filepath.Join(dirname, in) + out = filepath.Join(dirname, out) + + // process flags + *simplifyAST = false + *rewriteRule = "" + for _, flag := range strings.Split(flags, " ", -1) { + elts := strings.Split(flag, "=", 2) + name := elts[0] + value := "" + if len(elts) == 2 { + value = elts[1] + } + switch name { + case "": + // no flags + case "-r": + *rewriteRule = value + case "-s": + *simplifyAST = true + default: + t.Errorf("unrecognized flag name: %s", name) + } + } + + initParserMode() + initPrinterMode() + initRewrite() + + var buf bytes.Buffer + err := processFile(in, nil, &buf) + if err != nil { + t.Error(err) + return + } + + expected, err := ioutil.ReadFile(out) + if err != nil { + t.Error(err) + return + } + + if got := buf.Bytes(); bytes.Compare(got, expected) != 0 { + t.Errorf("(gofmt %s) != %s (see %s.gofmt)", in, out, in) + ioutil.WriteFile(in+".gofmt", got, 0666) + } +} + + +// TODO(gri) Add more test cases! +var tests = []struct { + dirname, in, out, flags string +}{ + {".", "gofmt.go", "gofmt.go", ""}, + {".", "gofmt_test.go", "gofmt_test.go", ""}, + {"testdata", "composites.input", "composites.golden", "-s"}, + {"testdata", "rewrite1.input", "rewrite1.golden", "-r=Foo->Bar"}, +} + + +func TestRewrite(t *testing.T) { + for _, test := range tests { + runTest(t, test.dirname, test.in, test.out, test.flags) + } +} diff --git a/src/cmd/gofmt/rewrite.go b/src/cmd/gofmt/rewrite.go index fbcd46aa2..93643dced 100644 --- a/src/cmd/gofmt/rewrite.go +++ b/src/cmd/gofmt/rewrite.go @@ -46,6 +46,16 @@ func parseExpr(s string, what string) ast.Expr { } +// Keep this function for debugging. +/* +func dump(msg string, val reflect.Value) { + fmt.Printf("%s:\n", msg) + ast.Print(fset, val.Interface()) + fmt.Println() +} +*/ + + // rewriteFile applies the rewrite rule 'pattern -> replace' to an entire file. func rewriteFile(pattern, replace ast.Expr, p *ast.File) *ast.File { m := make(map[string]reflect.Value) @@ -54,7 +64,7 @@ func rewriteFile(pattern, replace ast.Expr, p *ast.File) *ast.File { var f func(val reflect.Value) reflect.Value // f is recursive f = func(val reflect.Value) reflect.Value { for k := range m { - m[k] = nil, false + m[k] = reflect.Value{}, false } val = apply(f, val) if match(m, pat, val) { @@ -78,28 +88,45 @@ func setValue(x, y reflect.Value) { panic(x) } }() - x.SetValue(y) + x.Set(y) } +// Values/types for special cases. +var ( + objectPtrNil = reflect.NewValue((*ast.Object)(nil)) + + identType = reflect.Typeof((*ast.Ident)(nil)) + objectPtrType = reflect.Typeof((*ast.Object)(nil)) + positionType = reflect.Typeof(token.NoPos) +) + + // apply replaces each AST field x in val with f(x), returning val. // To avoid extra conversions, f operates on the reflect.Value form. 
func apply(f func(reflect.Value) reflect.Value, val reflect.Value) reflect.Value { - if val == nil { - return nil + if !val.IsValid() { + return reflect.Value{} + } + + // *ast.Objects introduce cycles and are likely incorrect after + // rewrite; don't follow them but replace with nil instead + if val.Type() == objectPtrType { + return objectPtrNil } - switch v := reflect.Indirect(val).(type) { - case *reflect.SliceValue: + + switch v := reflect.Indirect(val); v.Kind() { + case reflect.Slice: for i := 0; i < v.Len(); i++ { - e := v.Elem(i) + e := v.Index(i) setValue(e, f(e)) } - case *reflect.StructValue: + case reflect.Struct: for i := 0; i < v.NumField(); i++ { e := v.Field(i) setValue(e, f(e)) } - case *reflect.InterfaceValue: + case reflect.Interface: e := v.Elem() setValue(v, f(e)) } @@ -107,10 +134,6 @@ func apply(f func(reflect.Value) reflect.Value, val reflect.Value) reflect.Value } -var positionType = reflect.Typeof(token.NoPos) -var identType = reflect.Typeof((*ast.Ident)(nil)) - - func isWildcard(s string) bool { rune, size := utf8.DecodeRuneInString(s) return size == len(s) && unicode.IsLower(rune) @@ -124,9 +147,9 @@ func match(m map[string]reflect.Value, pattern, val reflect.Value) bool { // Wildcard matches any expression. If it appears multiple // times in the pattern, it must match the same expression // each time. - if m != nil && pattern != nil && pattern.Type() == identType { + if m != nil && pattern.IsValid() && pattern.Type() == identType { name := pattern.Interface().(*ast.Ident).Name - if isWildcard(name) && val != nil { + if isWildcard(name) && val.IsValid() { // wildcards only match expressions if _, ok := val.Interface().(ast.Expr); ok { if old, ok := m[name]; ok { @@ -139,8 +162,8 @@ func match(m map[string]reflect.Value, pattern, val reflect.Value) bool { } // Otherwise, pattern and val must match recursively. - if pattern == nil || val == nil { - return pattern == nil && val == nil + if !pattern.IsValid() || !val.IsValid() { + return !pattern.IsValid() && !val.IsValid() } if pattern.Type() != val.Type() { return false @@ -148,9 +171,6 @@ func match(m map[string]reflect.Value, pattern, val reflect.Value) bool { // Special cases. switch pattern.Type() { - case positionType: - // token positions don't need to match - return true case identType: // For identifiers, only the names need to match // (and none of the other *ast.Object information). 
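Editor's aside, not part of the patch: the hunks on either side of this point convert gofmt's rewrite helpers (apply, match, subst) from the old reflect value types (*reflect.SliceValue, *reflect.StructValue, *reflect.InterfaceValue) to Kind-based dispatch on a plain reflect.Value, with IsValid() replacing nil checks and Index/Set replacing Elem(i)/SetValue. The sketch below is a minimal, self-contained illustration of that traversal style only; it uses today's spelling of the API (reflect.ValueOf) rather than the 2011 names that appear in this diff (reflect.Typeof, reflect.NewValue), and walk, visit, and pair are invented names for the example, not identifiers from the patch.

package main

import (
	"fmt"
	"reflect"
)

// walk visits v and its elements using Kind-based dispatch, the style
// the patch above converts gofmt's apply/match/subst helpers to.
func walk(v reflect.Value, visit func(reflect.Value)) {
	if !v.IsValid() { // the zero Value replaces the old "val == nil" checks
		return
	}
	visit(v)
	switch v = reflect.Indirect(v); v.Kind() {
	case reflect.Slice:
		for i := 0; i < v.Len(); i++ {
			walk(v.Index(i), visit) // Index replaces the old Elem(i)
		}
	case reflect.Struct:
		for i := 0; i < v.NumField(); i++ {
			walk(v.Field(i), visit)
		}
	case reflect.Interface:
		walk(v.Elem(), visit)
	}
}

func main() {
	type pair struct{ A, B int }
	// Prints the kind and value of the slice, each struct, and each int field.
	walk(reflect.ValueOf([]pair{{1, 2}, {3, 4}}), func(v reflect.Value) {
		fmt.Println(v.Kind(), v.Interface())
	})
}
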
@@ -159,29 +179,30 @@ func match(m map[string]reflect.Value, pattern, val reflect.Value) bool { p := pattern.Interface().(*ast.Ident) v := val.Interface().(*ast.Ident) return p == nil && v == nil || p != nil && v != nil && p.Name == v.Name + case objectPtrType, positionType: + // object pointers and token positions don't need to match + return true } p := reflect.Indirect(pattern) v := reflect.Indirect(val) - if p == nil || v == nil { - return p == nil && v == nil + if !p.IsValid() || !v.IsValid() { + return !p.IsValid() && !v.IsValid() } - switch p := p.(type) { - case *reflect.SliceValue: - v := v.(*reflect.SliceValue) + switch p.Kind() { + case reflect.Slice: if p.Len() != v.Len() { return false } for i := 0; i < p.Len(); i++ { - if !match(m, p.Elem(i), v.Elem(i)) { + if !match(m, p.Index(i), v.Index(i)) { return false } } return true - case *reflect.StructValue: - v := v.(*reflect.StructValue) + case reflect.Struct: if p.NumField() != v.NumField() { return false } @@ -192,8 +213,7 @@ func match(m map[string]reflect.Value, pattern, val reflect.Value) bool { } return true - case *reflect.InterfaceValue: - v := v.(*reflect.InterfaceValue) + case reflect.Interface: return match(m, p.Elem(), v.Elem()) } @@ -207,8 +227,8 @@ func match(m map[string]reflect.Value, pattern, val reflect.Value) bool { // if m == nil, subst returns a copy of pattern and doesn't change the line // number information. func subst(m map[string]reflect.Value, pattern reflect.Value, pos reflect.Value) reflect.Value { - if pattern == nil { - return nil + if !pattern.IsValid() { + return reflect.Value{} } // Wildcard gets replaced with map value. @@ -216,12 +236,12 @@ func subst(m map[string]reflect.Value, pattern reflect.Value, pos reflect.Value) name := pattern.Interface().(*ast.Ident).Name if isWildcard(name) { if old, ok := m[name]; ok { - return subst(nil, old, nil) + return subst(nil, old, reflect.Value{}) } } } - if pos != nil && pattern.Type() == positionType { + if pos.IsValid() && pattern.Type() == positionType { // use new position only if old position was valid in the first place if old := pattern.Interface().(token.Pos); !old.IsValid() { return pattern @@ -230,29 +250,33 @@ func subst(m map[string]reflect.Value, pattern reflect.Value, pos reflect.Value) } // Otherwise copy. 
- switch p := pattern.(type) { - case *reflect.SliceValue: - v := reflect.MakeSlice(p.Type().(*reflect.SliceType), p.Len(), p.Len()) + switch p := pattern; p.Kind() { + case reflect.Slice: + v := reflect.MakeSlice(p.Type(), p.Len(), p.Len()) for i := 0; i < p.Len(); i++ { - v.Elem(i).SetValue(subst(m, p.Elem(i), pos)) + v.Index(i).Set(subst(m, p.Index(i), pos)) } return v - case *reflect.StructValue: - v := reflect.MakeZero(p.Type()).(*reflect.StructValue) + case reflect.Struct: + v := reflect.Zero(p.Type()) for i := 0; i < p.NumField(); i++ { - v.Field(i).SetValue(subst(m, p.Field(i), pos)) + v.Field(i).Set(subst(m, p.Field(i), pos)) } return v - case *reflect.PtrValue: - v := reflect.MakeZero(p.Type()).(*reflect.PtrValue) - v.PointTo(subst(m, p.Elem(), pos)) + case reflect.Ptr: + v := reflect.Zero(p.Type()) + if elem := p.Elem(); elem.IsValid() { + v.Set(subst(m, elem, pos).Addr()) + } return v - case *reflect.InterfaceValue: - v := reflect.MakeZero(p.Type()).(*reflect.InterfaceValue) - v.SetValue(subst(m, p.Elem(), pos)) + case reflect.Interface: + v := reflect.Zero(p.Type()) + if elem := p.Elem(); elem.IsValid() { + v.Set(subst(m, elem, pos)) + } return v } diff --git a/src/cmd/gofmt/testdata/rewrite1.golden b/src/cmd/gofmt/testdata/rewrite1.golden new file mode 100644 index 000000000..3f909ff4a --- /dev/null +++ b/src/cmd/gofmt/testdata/rewrite1.golden @@ -0,0 +1,8 @@ +package main + +type Bar int + +func main() { + var a Bar + println(a) +} diff --git a/src/cmd/gofmt/testdata/rewrite1.input b/src/cmd/gofmt/testdata/rewrite1.input new file mode 100644 index 000000000..1f10e3601 --- /dev/null +++ b/src/cmd/gofmt/testdata/rewrite1.input @@ -0,0 +1,8 @@ +package main + +type Foo int + +func main() { + var a Foo + println(a) +} diff --git a/src/cmd/gofmt/testdata/test.sh b/src/cmd/gofmt/testdata/test.sh deleted file mode 100755 index a1d5d823e..000000000 --- a/src/cmd/gofmt/testdata/test.sh +++ /dev/null @@ -1,65 +0,0 @@ -#!/usr/bin/env bash -# Copyright 2010 The Go Authors. All rights reserved. -# Use of this source code is governed by a BSD-style -# license that can be found in the LICENSE file. - -CMD="../gofmt" -TMP=test_tmp.go -COUNT=0 - - -cleanup() { - rm -f $TMP -} - - -error() { - echo $1 - exit 1 -} - - -count() { - #echo $1 - let COUNT=$COUNT+1 - let M=$COUNT%10 - if [ $M == 0 ]; then - echo -n "." - fi -} - - -test() { - count $1 - - # compare against .golden file - cleanup - $CMD -s $1 > $TMP - cmp -s $TMP $2 - if [ $? != 0 ]; then - diff $TMP $2 - error "Error: simplified $1 does not match $2" - fi - - # make sure .golden is idempotent - cleanup - $CMD -s $2 > $TMP - cmp -s $TMP $2 - if [ $? != 0 ]; then - diff $TMP $2 - error "Error: $2 is not idempotent" - fi -} - - -runtests() { - smoketest=../../../pkg/go/parser/parser.go - test $smoketest $smoketest - test composites.input composites.golden - # add more test cases here -} - - -runtests -cleanup -echo "PASSED ($COUNT tests)" diff --git a/src/cmd/goinstall/Makefile b/src/cmd/goinstall/Makefile index 6ddb32be7..aaf202ee7 100644 --- a/src/cmd/goinstall/Makefile +++ b/src/cmd/goinstall/Makefile @@ -10,5 +10,20 @@ GOFILES=\ main.go\ make.go\ parse.go\ + syslist.go\ + +CLEANFILES+=syslist.go include ../../Make.cmd + +syslist.go: + echo '// Generated automatically by make.' 
>$@ + echo 'package main' >>$@ + echo 'const goosList = "$(GOOS_LIST)"' >>$@ + echo 'const goarchList = "$(GOARCH_LIST)"' >>$@ + +test: + gotest + +testshort: + gotest -test.short diff --git a/src/cmd/goinstall/doc.go b/src/cmd/goinstall/doc.go index 17cc06969..15845b574 100644 --- a/src/cmd/goinstall/doc.go +++ b/src/cmd/goinstall/doc.go @@ -14,6 +14,7 @@ Usage: Flags and default settings: -a=false install all previously installed packages + -clean=false clean the package directory before installing -dashboard=true tally public packages on godashboard.appspot.com -log=true log installed packages to $GOROOT/goinstall.log for use by -a -u=false update already-downloaded packages diff --git a/src/cmd/goinstall/main.go b/src/cmd/goinstall/main.go index 34441be45..8fec8e312 100644 --- a/src/cmd/goinstall/main.go +++ b/src/cmd/goinstall/main.go @@ -120,7 +120,7 @@ func logPackage(pkg string) { if installedPkgs[pkg] { return } - fout, err := os.Open(logfile, os.O_WRONLY|os.O_APPEND|os.O_CREAT, 0644) + fout, err := os.OpenFile(logfile, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0666) if err != nil { fmt.Fprintf(os.Stderr, "%s: %s\n", argv0, err) return diff --git a/src/cmd/goinstall/make.go b/src/cmd/goinstall/make.go index e2d99bb47..ceb119e5a 100644 --- a/src/cmd/goinstall/make.go +++ b/src/cmd/goinstall/make.go @@ -52,12 +52,6 @@ func makeMakefile(dir, pkg string) ([]byte, os.Error) { return nil, err } - if len(dirInfo.cgoFiles) == 0 && len(dirInfo.cFiles) > 0 { - // When using cgo, .c files are compiled with gcc. Without cgo, - // they may be intended for 6c. Just error out for now. - return nil, os.ErrorString("C files found in non-cgo package") - } - cgoFiles := dirInfo.cgoFiles isCgo := make(map[string]bool, len(cgoFiles)) for _, file := range cgoFiles { @@ -67,26 +61,40 @@ func makeMakefile(dir, pkg string) ([]byte, os.Error) { isCgo[file] = true } - oFiles := make([]string, 0, len(dirInfo.cFiles)) - for _, file := range dirInfo.cFiles { + goFiles := make([]string, 0, len(dirInfo.goFiles)) + for _, file := range dirInfo.goFiles { if !safeName(file) { return nil, os.ErrorString("unsafe name: " + file) } - oFiles = append(oFiles, file[:len(file)-2]+".o") + if !isCgo[file] { + goFiles = append(goFiles, file) + } } - goFiles := make([]string, 0, len(dirInfo.goFiles)) - for _, file := range dirInfo.goFiles { + oFiles := make([]string, 0, len(dirInfo.cFiles)+len(dirInfo.sFiles)) + cgoOFiles := make([]string, 0, len(dirInfo.cFiles)) + for _, file := range dirInfo.cFiles { if !safeName(file) { return nil, os.ErrorString("unsafe name: " + file) } - if !isCgo[file] { - goFiles = append(goFiles, file) + // When cgo is in use, C files are compiled with gcc, + // otherwise they're compiled with gc. + if len(cgoFiles) > 0 { + cgoOFiles = append(cgoOFiles, file[:len(file)-2]+".o") + } else { + oFiles = append(oFiles, file[:len(file)-2]+".$O") } } + for _, file := range dirInfo.sFiles { + if !safeName(file) { + return nil, os.ErrorString("unsafe name: " + file) + } + oFiles = append(oFiles, file[:len(file)-2]+".$O") + } + var buf bytes.Buffer - md := makedata{pkg, goFiles, cgoFiles, oFiles} + md := makedata{pkg, goFiles, oFiles, cgoFiles, cgoOFiles} if err := makefileTemplate.Execute(&buf, &md); err != nil { return nil, err } @@ -106,10 +114,11 @@ func safeName(s string) bool { // makedata is the data type for the makefileTemplate. 
type makedata struct { - Pkg string // package import path - GoFiles []string // list of non-cgo .go files - CgoFiles []string // list of cgo .go files - OFiles []string // list of ofiles for cgo + Pkg string // package import path + GoFiles []string // list of non-cgo .go files + OFiles []string // list of .$O files + CgoFiles []string // list of cgo .go files + CgoOFiles []string // list of cgo .o files, without extension } var makefileTemplate = template.MustParse(` @@ -124,6 +133,13 @@ GOFILES=\ {.end} {.end} +{.section OFiles} +OFILES=\ +{.repeated section OFiles} + {@}\ +{.end} + +{.end} {.section CgoFiles} CGOFILES=\ {.repeated section CgoFiles} @@ -131,9 +147,9 @@ CGOFILES=\ {.end} {.end} -{.section OFiles} +{.section CgoOFiles} CGO_OFILES=\ -{.repeated section OFiles} +{.repeated section CgoOFiles} {@}\ {.end} diff --git a/src/cmd/goinstall/parse.go b/src/cmd/goinstall/parse.go index 014b8fcb2..0e617903c 100644 --- a/src/cmd/goinstall/parse.go +++ b/src/cmd/goinstall/parse.go @@ -14,6 +14,7 @@ import ( "path/filepath" "strconv" "strings" + "runtime" ) @@ -21,6 +22,7 @@ type dirInfo struct { goFiles []string // .go files within dir (including cgoFiles) cgoFiles []string // .go files that import "C" cFiles []string // .c files within dir + sFiles []string // .s files within dir imports []string // All packages imported by goFiles pkgName string // Name of package within dir } @@ -37,7 +39,7 @@ type dirInfo struct { // The imports map keys are package paths imported by listed Go files, // and the values are the Go files importing the respective package paths. func scanDir(dir string, allowMain bool) (info *dirInfo, err os.Error) { - f, err := os.Open(dir, os.O_RDONLY, 0) + f, err := os.Open(dir) if err != nil { return nil, err } @@ -50,6 +52,7 @@ func scanDir(dir string, allowMain bool) (info *dirInfo, err os.Error) { goFiles := make([]string, 0, len(dirs)) cgoFiles := make([]string, 0, len(dirs)) cFiles := make([]string, 0, len(dirs)) + sFiles := make([]string, 0, len(dirs)) importsm := make(map[string]bool) pkgName := "" for i := range dirs { @@ -57,13 +60,25 @@ func scanDir(dir string, allowMain bool) (info *dirInfo, err os.Error) { if strings.HasPrefix(d.Name, "_") || strings.Index(d.Name, ".cgo") != -1 { continue } - if strings.HasSuffix(d.Name, ".c") { - cFiles = append(cFiles, d.Name) + if !goodOSArch(d.Name) { continue } - if !strings.HasSuffix(d.Name, ".go") || strings.HasSuffix(d.Name, "_test.go") { + + switch filepath.Ext(d.Name) { + case ".go": + if strings.HasSuffix(d.Name, "_test.go") { + continue + } + case ".c": + cFiles = append(cFiles, d.Name) + continue + case ".s": + sFiles = append(sFiles, d.Name) + continue + default: continue } + filename := filepath.Join(dir, d.Name) pf, err := parser.ParseFile(fset, filename, nil, parser.ImportsOnly) if err != nil { @@ -106,5 +121,49 @@ func scanDir(dir string, allowMain bool) (info *dirInfo, err os.Error) { imports[i] = p i++ } - return &dirInfo{goFiles, cgoFiles, cFiles, imports, pkgName}, nil + return &dirInfo{goFiles, cgoFiles, cFiles, sFiles, imports, pkgName}, nil +} + +// goodOSArch returns false if the filename contains a $GOOS or $GOARCH +// suffix which does not match the current system. 
+// The recognized filename formats are: +// +// name_$(GOOS).* +// name_$(GOARCH).* +// name_$(GOOS)_$(GOARCH).* +// +func goodOSArch(filename string) bool { + if dot := strings.Index(filename, "."); dot != -1 { + filename = filename[:dot] + } + l := strings.Split(filename, "_", -1) + n := len(l) + if n == 0 { + return true + } + if good, known := goodOS[l[n-1]]; known { + return good + } + if good, known := goodArch[l[n-1]]; known { + if !good || n < 2 { + return false + } + good, known = goodOS[l[n-2]] + return good || !known + } + return true +} + +var goodOS = make(map[string]bool) +var goodArch = make(map[string]bool) + +func init() { + goodOS = make(map[string]bool) + goodArch = make(map[string]bool) + for _, v := range strings.Fields(goosList) { + goodOS[v] = v == runtime.GOOS + } + for _, v := range strings.Fields(goarchList) { + goodArch[v] = v == runtime.GOARCH + } } diff --git a/src/cmd/goinstall/syslist_test.go b/src/cmd/goinstall/syslist_test.go new file mode 100644 index 000000000..795cd293a --- /dev/null +++ b/src/cmd/goinstall/syslist_test.go @@ -0,0 +1,61 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +package main + +import ( + "runtime" + "testing" +) + +var ( + thisOS = runtime.GOOS + thisArch = runtime.GOARCH + otherOS = anotherOS() + otherArch = anotherArch() +) + +func anotherOS() string { + if thisOS != "darwin" { + return "darwin" + } + return "linux" +} + +func anotherArch() string { + if thisArch != "amd64" { + return "amd64" + } + return "386" +} + +type GoodFileTest struct { + name string + result bool +} + +var tests = []GoodFileTest{ + {"file.go", true}, + {"file.c", true}, + {"file_foo.go", true}, + {"file_" + thisArch + ".go", true}, + {"file_" + otherArch + ".go", false}, + {"file_" + thisOS + ".go", true}, + {"file_" + otherOS + ".go", false}, + {"file_" + thisOS + "_" + thisArch + ".go", true}, + {"file_" + otherOS + "_" + thisArch + ".go", false}, + {"file_" + thisOS + "_" + otherArch + ".go", false}, + {"file_" + otherOS + "_" + otherArch + ".go", false}, + {"file_foo_" + thisArch + ".go", true}, + {"file_foo_" + otherArch + ".go", false}, + {"file_" + thisOS + ".c", true}, + {"file_" + otherOS + ".c", false}, +} + +func TestGoodOSArch(t *testing.T) { + for _, test := range tests { + if goodOSArch(test.name) != test.result { + t.Fatalf("goodOSArch(%q) != %v", test.name, test.result) + } + } +} diff --git a/src/cmd/gopack/ar.c b/src/cmd/gopack/ar.c index a7e2c41af..dc3899f37 100644 --- a/src/cmd/gopack/ar.c +++ b/src/cmd/gopack/ar.c @@ -41,6 +41,7 @@ #include <libc.h> #include <bio.h> #include <mach.h> +#include "../../libmach/obj.h" #include <ar.h> #undef select @@ -123,6 +124,7 @@ int gflag; int oflag; int uflag; int vflag; +int Pflag; /* remove leading file prefix */ int Sflag; /* force mark Go package as safe */ int errors; @@ -141,6 +143,7 @@ char poname[ARNAMESIZE+1]; /* name of pivot member */ char *file; /* current file or member being worked on */ Biobuf bout; Biobuf bar; +char *prefix; void arcopy(Biobuf*, Arfile*, Armember*); int arcreate(char*); @@ -149,7 +152,7 @@ void arinsert(Arfile*, Armember*); void *armalloc(int); char *arstrdup(char*); void armove(Biobuf*, Arfile*, Armember*); -void arread(Biobuf*, Armember*, int); +void arread(Biobuf*, Armember*); void arstream(int, Arfile*); int arwrite(int, Armember*); int bamatch(char*, char*); @@ -179,6 +182,7 @@ void trim(char*, char*, int); void usage(void); void wrerr(void); 
void wrsym(Biobuf*, long, Arsymref*); +int arread_cutprefix(Biobuf*, Armember*); void rcmd(char*, int, char**); /* command processing */ void dcmd(char*, int, char**); @@ -220,6 +224,7 @@ main(int argc, char *argv[]) case 'v': vflag = 1; break; case 'x': setcom(xcmd); break; case 'S': Sflag = 1; break; + case 'P': Pflag = 1; break; default: fprint(2, "gopack: bad option `%c'\n", *cp); exits("error"); @@ -236,6 +241,15 @@ main(int argc, char *argv[]) if(argc < 3) usage(); } + if(Pflag) { + if(argc < 4) { + fprint(2, "gopack: P flag requires prefix argument\n"); + usage(); + } + prefix = argv[2]; + argv++; + argc--; + } if(comfun == 0) { if(uflag == 0) { fprint(2, "gopack: one of [%s] must be specified\n", man); @@ -313,7 +327,16 @@ rcmd(char *arname, int count, char **files) skip(&bar, bp->size); continue; } - if (count && !match(count, files)) { + /* + * the plan 9 ar treats count == 0 as equivalent + * to listing all the archive's files on the command line: + * it will try to open every file name in the archive + * and copy that file into the archive if it exists. + * for go we disable that behavior, because we use + * r with no files to make changes to the archive itself, + * using the S or P flags. + */ + if (!match(count, files)) { scanobj(&bar, ap, bp->size); arcopy(&bar, ap, bp); continue; @@ -430,7 +453,7 @@ xcmd(char *arname, int count, char **files) arcopy(&bar, 0, bp); if (write(f, bp->member, bp->size) < 0) wrerr(); - if(oflag) { + if(oflag && bp->date != 0) { nulldir(&dx); dx.atime = bp->date; dx.mtime = bp->date; @@ -972,7 +995,7 @@ phaseerr(int offset) void usage(void) { - fprint(2, "usage: gopack [%s][%s] archive files ...\n", opt, man); + fprint(2, "usage: gopack [%s][%s][P prefix] archive files ...\n", opt, man); exits("error"); } @@ -1012,29 +1035,32 @@ armove(Biobuf *b, Arfile *ap, Armember *bp) { char *cp; Dir *d; + vlong n; d = dirfstat(Bfildes(b)); if (d == nil) { fprint(2, "gopack: cannot stat %s\n", file); return; } + trim(file, bp->hdr.name, sizeof(bp->hdr.name)); for (cp = strchr(bp->hdr.name, 0); /* blank pad on right */ cp < bp->hdr.name+sizeof(bp->hdr.name); cp++) *cp = ' '; - sprint(bp->hdr.date, "%-12ld", d->mtime); + sprint(bp->hdr.date, "%-12ld", 0); // was d->mtime but removed for idempotent builds sprint(bp->hdr.uid, "%-6d", 0); sprint(bp->hdr.gid, "%-6d", 0); sprint(bp->hdr.mode, "%-8lo", d->mode); sprint(bp->hdr.size, "%-10lld", d->length); strncpy(bp->hdr.fmag, ARFMAG, 2); bp->size = d->length; - arread(b, bp, bp->size); - if (d->length&0x01) - d->length++; + arread(b, bp); + n = bp->size; + if (n&1) + n++; if (ap) { arinsert(ap, bp); - ap->size += d->length+SAR_HDR; + ap->size += n+SAR_HDR; } free(d); } @@ -1047,10 +1073,10 @@ arcopy(Biobuf *b, Arfile *ap, Armember *bp) { long n; + arread(b, bp); n = bp->size; if (n & 01) n++; - arread(b, bp, n); if (ap) { arinsert(ap, bp); ap->size += n+SAR_HDR; @@ -1125,7 +1151,7 @@ rl(int fd) len = symdefsize; if(len&01) len++; - sprint(a.date, "%-12ld", time(0)); + sprint(a.date, "%-12ld", 0); // time(0) sprint(a.uid, "%-6d", 0); sprint(a.gid, "%-6d", 0); sprint(a.mode, "%-8lo", 0644L); @@ -1162,7 +1188,7 @@ rl(int fd) if (gflag) { len = pkgdefsize; - sprint(a.date, "%-12ld", time(0)); + sprint(a.date, "%-12ld", 0); // time(0) sprint(a.uid, "%-6d", 0); sprint(a.gid, "%-6d", 0); sprint(a.mode, "%-8lo", 0644L); @@ -1316,7 +1342,8 @@ longt(Armember *bp) Bprint(&bout, "%7ld", bp->size); date = bp->date; cp = ctime(&date); - Bprint(&bout, " %-12.12s %-4.4s ", cp+4, cp+24); + /* using unix ctime, not plan 9 time, so 
cp+20 for year, not cp+24 */ + Bprint(&bout, " %-12.12s %-4.4s ", cp+4, cp+20); } int m1[] = { 1, ROWN, 'r', '-' }; @@ -1378,17 +1405,29 @@ newmember(void) /* allocate a member buffer */ } void -arread(Biobuf *b, Armember *bp, int n) /* read an image into a member buffer */ +arread(Biobuf *b, Armember *bp) /* read an image into a member buffer */ { int i; + vlong off; - bp->member = armalloc(n); - i = Bread(b, bp->member, n); + bp->member = armalloc(bp->size); + + // If P flag is set, let arread_cutprefix try. + // If it succeeds, we're done. If not, fall back + // to a direct copy. + off = Boffset(b); + if(Pflag && arread_cutprefix(b, bp)) + return; + Bseek(b, off, 0); + + i = Bread(b, bp->member, bp->size); if (i < 0) { free(bp->member); bp->member = 0; rderr(); } + if(bp->size&1) + Bgetc(b); } /* @@ -1551,3 +1590,109 @@ arstrdup(char *s) } +/* + * Parts of libmach we're not supposed + * to look at but need for arread_cutprefix. + */ +extern int _read5(Biobuf*, Prog*); +extern int _read6(Biobuf*, Prog*); +extern int _read8(Biobuf*, Prog*); +int (*reader[256])(Biobuf*, Prog*) = { + [ObjArm] = _read5, + [ObjAmd64] = _read6, + [Obj386] = _read8, +}; + +/* + * copy b into bp->member but rewrite object + * during copy to drop prefix from all file names. + * return 1 if b was recognized as an object file + * and copied successfully, 0 otherwise. + */ +int +arread_cutprefix(Biobuf *b, Armember *bp) +{ + vlong offset, o, end; + int n, t; + int (*rd)(Biobuf*, Prog*); + char *w, *inprefix; + Prog p; + + offset = Boffset(b); + end = offset + bp->size; + t = objtype(b, nil); + if(t < 0) + return 0; + if((rd = reader[t]) == nil) + return 0; + + // copy header + w = bp->member; + n = Boffset(b) - offset; + Bseek(b, -n, 1); + if(Bread(b, w, n) != n) + return 0; + offset += n; + w += n; + + // read object file one pseudo-instruction at a time, + // eliding the file name instructions that refer to + // the prefix. + memset(&p, 0, sizeof p); + inprefix = nil; + while(Boffset(b) < end && rd(b, &p)) { + if(p.kind == aName && p.type == UNKNOWN && p.sym == 1 && p.id[0] == '<') { + // part of a file path. + // we'll keep continuing (skipping the copy) + // around the loop until either we get to a + // name piece that should be kept or we see + // the whole prefix. + + if(inprefix == nil && prefix[0] == '/' && p.id[1] == '/' && p.id[2] == '\0') { + // leading / + inprefix = prefix+1; + } else if(inprefix != nil) { + // handle subsequent elements + n = strlen(p.id+1); + if(strncmp(p.id+1, inprefix, n) == 0 && (inprefix[n] == '/' || inprefix[n] == '\0')) { + inprefix += n; + if(inprefix[0] == '/') + inprefix++; + } + } + + if(inprefix && inprefix[0] == '\0') { + // reached end of prefix. + // if we another path element follows, + // nudge the offset to skip over the prefix we saw. + // if not, leave offset alone, to emit the whole name. + // additional name elements will not be skipped + // because inprefix is now nil and we won't see another + // leading / in this name. 
+ inprefix = nil; + o = Boffset(b); + if(o < end && rd(b, &p) && p.kind == aName && p.type == UNKNOWN && p.sym == 1 && p.id[0] == '<') { + // print("skip %lld-%lld\n", offset, o); + offset = o; + } + } + } + + // copy instructions + if(!inprefix) { + n = Boffset(b) - offset; + Bseek(b, -n, 1); + if(Bread(b, w, n) != n) + return 0; + offset += n; + w += n; + } + } + bp->size = w - (char*)bp->member; + sprint(bp->hdr.size, "%-10lld", (vlong)bp->size); + strncpy(bp->hdr.fmag, ARFMAG, 2); + Bseek(b, end, 0); + if(Boffset(b)&1) + Bgetc(b); + return 1; +} diff --git a/src/cmd/gopack/doc.go b/src/cmd/gopack/doc.go index 08711e72e..1551a275f 100644 --- a/src/cmd/gopack/doc.go +++ b/src/cmd/gopack/doc.go @@ -12,12 +12,15 @@ It adds a special Go-specific section __.PKGDEF that collects all the Go type information from the files in the archive; that section is used by the compiler when importing the package during compilation. -Usage: gopack [uvnbailogS][mrxtdpq] archive files ... +Usage: gopack [uvnbailogS][mrxtdpq][P prefix] archive files ... The new option 'g' causes gopack to maintain the __.PKGDEF section as files are added to the archive. The new option 'S' forces gopack to mark the archive as safe. +The new option 'P' causes gopack to remove the given prefix +from file names in the line number information in object files +that are already stored in or added to the archive. */ package documentation diff --git a/src/cmd/gotest/Makefile b/src/cmd/gotest/Makefile index 74054e974..5c1154537 100644 --- a/src/cmd/gotest/Makefile +++ b/src/cmd/gotest/Makefile @@ -4,15 +4,9 @@ include ../../Make.inc -TARG=install - -clean: - @true - -install: install-gotest install-gotry - -install-%: % - ! test -f "$(GOBIN)"/$* || chmod u+w "$(GOBIN)"/$* - sed 's`@@GOROOT@@`$(GOROOT_FINAL)`' $* >"$(GOBIN)"/$* - chmod +x "$(GOBIN)"/$* +TARG=gotest +GOFILES=\ + flag.go\ + gotest.go\ +include ../../Make.cmd diff --git a/src/cmd/gotest/doc.go b/src/cmd/gotest/doc.go index 581eaaab9..9dba390c1 100644 --- a/src/cmd/gotest/doc.go +++ b/src/cmd/gotest/doc.go @@ -6,22 +6,23 @@ Gotest is an automated testing tool for Go packages. -Normally a Go package is compiled without its test files. Gotest -is a simple script that recompiles the package along with any files -named *_test.go. Functions in the test sources named TestXXX -(where XXX is any alphanumeric string starting with an upper case -letter) will be run when the binary is executed. Gotest requires -that the package have a standard package Makefile, one that -includes go/src/Make.pkg. +Normally a Go package is compiled without its test files. Gotest is a +tool that recompiles the package whose source is in the current +directory, along with any files whose names match the pattern +"[^.]*_test.go". Functions in the test source named TestXXX (where +XXX is any alphanumeric string not starting with a lower case letter) +will be run when the binary is executed. Gotest requires that the +package have a standard package Makefile, one that includes +go/src/Make.pkg. The test functions are run in the order they appear in the source. -They should have signature +They should have the signature, func TestXXX(t *testing.T) { ... } -Benchmark functions can be written as well; they will be run only -when the -test.bench flag is provided. Benchmarks should have -signature +Benchmark functions can be written as well; they will be run only when +the -test.bench flag is provided. Benchmarks should have the +signature, func BenchmarkXXX(b *testing.B) { ... 
} @@ -29,28 +30,75 @@ See the documentation of the testing package for more information. By default, gotest needs no arguments. It compiles all the .go files in the directory, including tests, and runs the tests. If file names -are given, only those test files are added to the package. -(The non-test files are always compiled.) +are given (with flag -file=test.go, one per extra test source file), +only those test files are added to the package. (The non-test files +are always compiled.) The package is built in a special subdirectory so it does not interfere with the non-test installation. Usage: - gotest [pkg_test.go ...] + gotest [-file a.go -file b.go ...] [-c] [-x] [args for test binary] -The resulting binary, called (for amd64) 6.out, has a couple of -arguments. +The flags specific to gotest are: + -c Compile the test binary but do not run it. + -file a.go Use only the tests in the source file a.go. + Multiple -file flags may be provided. + -x Print each subcommand gotest executes. + +Everything else on the command line is passed to the test binary. + +The resulting test binary, called (for amd64) 6.out, has several flags. Usage: - 6.out [-test.v] [-test.run pattern] [-test.bench pattern] + 6.out [-test.v] [-test.run pattern] [-test.bench pattern] \ + [-test.cpuprofile=cpu.out] \ + [-test.memprofile=mem.out] [-test.memprofilerate=1] The -test.v flag causes the tests to be logged as they run. The -test.run flag causes only those tests whose names match the regular -expression pattern to be run. By default all tests are run silently. -If all the specified test pass, 6.out prints PASS and exits with a 0 -exit code. If any tests fail, it prints FAIL and exits with a -non-zero code. The -test.bench flag is analogous to the -test.run -flag, but applies to benchmarks. No benchmarks run by default. +expression pattern to be run. By default all tests are run silently. + +If all specified tests pass, 6.out prints the word PASS and exits with +a 0 exit code. If any tests fail, it prints error details, the word +FAIL, and exits with a non-zero code. The -test.bench flag is +analogous to the -test.run flag, but applies to benchmarks. No +benchmarks run by default. + +The -test.cpuprofile flag causes the testing software to write a CPU +profile to the specified file before exiting. + +The -test.memprofile flag causes the testing software to write a +memory profile to the specified file when all tests are complete. The +-test.memprofilerate flag enables more precise (and expensive) +profiles by setting runtime.MemProfileRate; run + godoc runtime MemProfileRate +for details. The defaults are no memory profile and the standard +setting of MemProfileRate. The memory profile records a sampling of +the memory in use at the end of the test. To profile all memory +allocations, use -test.memprofilerate=1 to sample every byte and set +the environment variable GOGC=off to disable the garbage collector, +provided the test can run in the available memory without garbage +collection. + +Use -test.run or -test.bench to limit profiling to a particular test +or benchmark. + +The -test.short flag tells long-running tests to shorten their run +time. It is off by default but set by all.bash so installations of +the Go tree can do a sanity check but not spend time running +exhaustive tests. + +The -test.timeout flag sets a timeout for the test in seconds. If the +test runs for longer than that, it will panic, dumping a stack trace +of all existing goroutines. 
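As a concrete companion to the gotest documentation above, here is a minimal test file of the shape it picks up; the package, function, and helper names are hypothetical:

	package example

	import "testing"

	func Sum(a, b int) int { return a + b }

	// TestSum is found because its name starts with "Test" followed
	// by a character that is not a lower-case letter.
	func TestSum(t *testing.T) {
		if got := Sum(2, 3); got != 5 {
			t.Errorf("Sum(2, 3) = %d, want 5", got)
		}
	}

	// BenchmarkSum runs only when -test.bench is supplied.
	func BenchmarkSum(b *testing.B) {
		for i := 0; i < b.N; i++ {
			Sum(2, 3)
		}
	}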
+ +For convenience, each of these -test.X flags of the test binary is +also available as the flag -X in gotest itself. Flags not listed here +are unaffected. For instance, the command + gotest -x -v -cpuprofile=prof.out -dir=testdata -update -file x_test.go +will compile the test binary using x_test.go and then run it as + 6.out -test.v -test.cpuprofile=prof.out -dir=testdata -update */ package documentation diff --git a/src/cmd/gotest/flag.go b/src/cmd/gotest/flag.go new file mode 100644 index 000000000..780c78b9c --- /dev/null +++ b/src/cmd/gotest/flag.go @@ -0,0 +1,155 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "fmt" + "os" + "strconv" + "strings" +) + +// The flag handling part of gotest is large and distracting. +// We can't use the flag package because some of the flags from +// our command line are for us, and some are for 6.out, and +// some are for both. + +var usageMessage = `Usage of %s: + -c=false: compile but do not run the test binary + -file=file: + -x=false: print command lines as they are executed + + // These flags can be passed with or without a "test." prefix: -v or -test.v. + -bench="": passes -test.bench to test + -cpuprofile="": passes -test.cpuprofile to test + -memprofile="": passes -test.memprofile to test + -memprofilerate=0: passes -test.memprofilerate to test + -run="": passes -test.run to test + -short=false: passes -test.short to test + -timeout=0: passes -test.timeout to test + -v=false: passes -test.v to test +` + +// usage prints a usage message and exits. +func usage() { + fmt.Fprintf(os.Stdout, usageMessage, os.Args[0]) + os.Exit(2) +} + +// flagSpec defines a flag we know about. +type flagSpec struct { + name string + isBool bool + passToTest bool // pass to Test + multiOK bool // OK to have multiple instances + present bool // flag has been seen +} + +// flagDefn is the set of flags we process. +var flagDefn = []*flagSpec{ + // gotest-local. + &flagSpec{name: "c", isBool: true}, + &flagSpec{name: "file", multiOK: true}, + &flagSpec{name: "x", isBool: true}, + + // passed to 6.out, adding a "test." prefix to the name if necessary: -v becomes -test.v. + &flagSpec{name: "bench", passToTest: true}, + &flagSpec{name: "cpuprofile", passToTest: true}, + &flagSpec{name: "memprofile", passToTest: true}, + &flagSpec{name: "memprofilerate", passToTest: true}, + &flagSpec{name: "run", passToTest: true}, + &flagSpec{name: "short", isBool: true, passToTest: true}, + &flagSpec{name: "timeout", passToTest: true}, + &flagSpec{name: "v", isBool: true, passToTest: true}, +} + +// flags processes the command line, grabbing -x and -c, rewriting known flags +// to have "test" before them, and reading the command line for the 6.out. +// Unfortunately for us, we need to do our own flag processing because gotest +// grabs some flags but otherwise its command line is just a holding place for +// 6.out's arguments. 
+func flags() { + for i := 1; i < len(os.Args); i++ { + arg := os.Args[i] + f, value, extraWord := flag(i) + if f == nil { + args = append(args, arg) + continue + } + switch f.name { + case "c": + setBoolFlag(&cFlag, value) + case "x": + setBoolFlag(&xFlag, value) + case "file": + fileNames = append(fileNames, value) + } + if extraWord { + i++ + } + if f.passToTest { + args = append(args, "-test."+f.name+"="+value) + } + } +} + +// flag sees if argument i is a known flag and returns its definition, value, and whether it consumed an extra word. +func flag(i int) (f *flagSpec, value string, extra bool) { + arg := os.Args[i] + if strings.HasPrefix(arg, "--") { // reduce two minuses to one + arg = arg[1:] + } + if arg == "" || arg[0] != '-' { + return + } + name := arg[1:] + // If there's already "test.", drop it for now. + if strings.HasPrefix(name, "test.") { + name = name[5:] + } + equals := strings.Index(name, "=") + if equals >= 0 { + value = name[equals+1:] + name = name[:equals] + } + for _, f = range flagDefn { + if name == f.name { + // Booleans are special because they have modes -x, -x=true, -x=false. + if f.isBool { + if equals < 0 { // otherwise, it's been set and will be verified in setBoolFlag + value = "true" + } else { + // verify it parses + setBoolFlag(new(bool), value) + } + } else { // Non-booleans must have a value. + extra = equals < 0 + if extra { + if i+1 >= len(os.Args) { + usage() + } + value = os.Args[i+1] + } + } + if f.present && !f.multiOK { + usage() + } + f.present = true + return + } + } + f = nil + return +} + +// setBoolFlag sets the addressed boolean to the value. +func setBoolFlag(flag *bool, value string) { + x, err := strconv.Atob(value) + if err != nil { + fmt.Fprintf(os.Stderr, "gotest: illegal bool flag value %s\n", value) + usage() + } + *flag = x +} diff --git a/src/cmd/gotest/gotest b/src/cmd/gotest/gotest deleted file mode 100755 index 69eaae730..000000000 --- a/src/cmd/gotest/gotest +++ /dev/null @@ -1,197 +0,0 @@ -#!/usr/bin/env bash -# Copyright 2009 The Go Authors. All rights reserved. -# Use of this source code is governed by a BSD-style -# license that can be found in the LICENSE file. - -# Using all the *_test.go files in the current directory, write out a file -# _testmain.go that runs all its tests. Compile everything and run the -# tests. -# If files are named on the command line, use them instead of *_test.go. - -# Makes egrep,grep work better in general if we put them -# in ordinary C mode instead of what the current language is. -unset LANG -export LC_ALL=C -export LC_CTYPE=C - -_GC=$GC # Make.inc will overwrite this - -if [ ! -f [Mm]akefile ]; then - echo 'please create a Makefile for gotest; see http://golang.org/doc/code.html for details' 1>&2 - exit 2 -fi - -export GOROOT=${GOROOT:-"@@GOROOT@@"} -eval $(gomake -j1 --no-print-directory -f "$GOROOT"/src/Make.inc go-env) -if [ -z "$O" ]; then - echo 'missing $O - maybe no Make.$GOARCH?' 1>&2 - exit 2 -fi - -E="$GORUN" - -# Allow overrides -GC="${_GC:-$GC} -I _test" -GL="${GL:-$LD} -L _test" -AS="$AS" -CC="$CC" -LD="$LD" -export GC GL O AS CC LD - -gofiles="" -loop=true -while $loop; do - case "x$1" in - x-*) - loop=false - ;; - x) - loop=false - ;; - *) - gofiles="$gofiles $1" - shift - ;; - esac -done - -case "x$gofiles" in -x) - gofiles=$(echo -n $(ls *_test.go 2>/dev/null)) -esac - -case "x$gofiles" in -x) - echo 'no test files found (*_test.go)' 1>&2 - exit 2 -esac - -# Run any commands given in sources, like -# // gotest: $GC foo.go -# to build any test-only dependencies. 
-sed -n 's/^\/\/ gotest: //p' $gofiles | sh -e || exit 1 - -# Split $gofiles into external gofiles (those in *_test packages) -# and internal ones (those in the main package). -xgofiles=$(echo $(grep '^package[ ]' $gofiles /dev/null | grep ':.*_test' | sed 's/:.*//')) -gofiles=$(echo $(grep '^package[ ]' $gofiles /dev/null | grep -v ':.*_test' | sed 's/:.*//')) - -# External $O file -xofile="" -havex=false -if [ "x$xgofiles" != "x" ]; then - xofile="_xtest_.$O" - havex=true -fi - -set -e - -gomake testpackage-clean -gomake testpackage "GOTESTFILES=$gofiles" -if $havex; then - $GC -o $xofile $xgofiles -fi - -# They all compile; now generate the code to call them. - -# Suppress output to stdout on Linux -MAKEFLAGS= -MAKELEVEL= - -# usage: nmgrep pattern file... -nmgrep() { - pat="$1" - shift - for i - do - # Package symbol "".Foo is pkg.Foo when imported in Go. - # Figure out pkg. - case "$i" in - *.a) - pkg=$(gopack p $i __.PKGDEF | sed -n 's/^package //p' | sed 's/ .*//' | sed 1q) - ;; - *) - pkg=$(sed -n 's/^ .* in package "\(.*\)".*/\1/p' $i | sed 1q) - ;; - esac - 6nm -s "$i" | egrep ' T .*\.'"$pat"'$' | - sed 's/.* //; /\..*\./d; s/""\./'"$pkg"'./g' - done -} - -localname() { - # The package main has been renamed to __main__ when imported. - # Adjust its uses. - echo $1 | sed 's/^main\./__main__./' -} - -importpath=$(gomake -s importpath) -{ - # test functions are named TestFoo - # the grep -v eliminates methods and other special names - # that have multiple dots. - pattern='Test([^a-z].*)?' - tests=$(nmgrep $pattern _test/$importpath.a $xofile) - if [ "x$tests" = x ]; then - echo 'gotest: error: no tests matching '$pattern in _test/$importpath.a $xofile 1>&2 - exit 2 - fi - # benchmarks are named BenchmarkFoo. - pattern='Benchmark([^a-z].*)?' - benchmarks=$(nmgrep $pattern _test/$importpath.a $xofile) - - # package spec - echo 'package main' - echo - # imports - if echo "$tests" | egrep -v '_test\.' >/dev/null; then - case "$importpath" in - testing) - ;; - main) - # Import path main is reserved, so import with - # explicit reference to ./_test/main instead. - # Also, the file we are writing defines a function named main, - # so rename this import to __main__ to avoid name conflict. - echo 'import __main__ "./_test/main"' - ;; - *) - echo 'import "'$importpath'"' - ;; - esac - fi - if $havex; then - echo 'import "./_xtest_"' - fi - echo 'import "testing"' - echo 'import __regexp__ "regexp"' # rename in case tested package is called regexp - # test array - echo - echo 'var tests = []testing.InternalTest{' - for i in $tests - do - j=$(localname $i) - echo ' {"'$i'", '$j'},' - done - echo '}' - # benchmark array - # The comment makes the multiline declaration - # gofmt-safe even when there are no benchmarks. - echo 'var benchmarks = []testing.InternalBenchmark{ //' - for i in $benchmarks - do - j=$(localname $i) - echo ' {"'$i'", '$j'},' - done - echo '}' - # body - echo - echo 'func main() {' - echo ' testing.Main(__regexp__.MatchString, tests)' - echo ' testing.RunBenchmarks(__regexp__.MatchString, benchmarks)' - echo '}' -}>_testmain.go - -$GC _testmain.go -$GL _testmain.$O -$E ./$O.out "$@" diff --git a/src/cmd/gotest/gotest.go b/src/cmd/gotest/gotest.go new file mode 100644 index 000000000..138216e68 --- /dev/null +++ b/src/cmd/gotest/gotest.go @@ -0,0 +1,415 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package main + +import ( + "bufio" + "exec" + "fmt" + "go/ast" + "go/parser" + "go/token" + "io/ioutil" + "os" + "path/filepath" + "runtime" + "strings" + "unicode" + "utf8" +) + +// Environment for commands. +var ( + XGC []string // 6g -I _test -o _xtest_.6 + GC []string // 6g -I _test _testmain.go + GL []string // 6l -L _test _testmain.6 + GOARCH string + GOROOT string + GORUN string + O string + args []string // arguments passed to gotest; also passed to the binary + fileNames []string + env = os.Environ() +) + +// These strings are created by getTestNames. +var ( + insideFileNames []string // list of *.go files inside the package. + outsideFileNames []string // list of *.go files outside the package (in package foo_test). +) + +var ( + files []*File + importPath string +) + +// Flags for our own purposes. We do our own flag processing. +var ( + cFlag bool + xFlag bool +) + +// File represents a file that contains tests. +type File struct { + name string + pkg string + file *os.File + astFile *ast.File + tests []string // The names of the TestXXXs. + benchmarks []string // The names of the BenchmarkXXXs. +} + +func main() { + flags() + needMakefile() + setEnvironment() + getTestFileNames() + parseFiles() + getTestNames() + run("gomake", "testpackage-clean") + run("gomake", "testpackage", fmt.Sprintf("GOTESTFILES=%s", strings.Join(insideFileNames, " "))) + if len(outsideFileNames) > 0 { + run(append(XGC, outsideFileNames...)...) + } + importPath = runWithStdout("gomake", "-s", "importpath") + writeTestmainGo() + run(GC...) + run(GL...) + if !cFlag { + runTestWithArgs("./" + O + ".out") + } +} + +// needMakefile tests that we have a Makefile in this directory. +func needMakefile() { + if _, err := os.Stat("Makefile"); err != nil { + Fatalf("please create a Makefile for gotest; see http://golang.org/doc/code.html for details") + } +} + +// Fatalf formats its arguments, prints the message with a final newline, and exits. +func Fatalf(s string, args ...interface{}) { + fmt.Fprintf(os.Stderr, "gotest: "+s+"\n", args...) + os.Exit(2) +} + +// theChar is the map from architecture to object character. +var theChar = map[string]string{ + "arm": "5", + "amd64": "6", + "386": "8", +} + +// addEnv adds a name=value pair to the environment passed to subcommands. +// If the item is already in the environment, addEnv replaces the value. +func addEnv(name, value string) { + for i := 0; i < len(env); i++ { + if strings.HasPrefix(env[i], name+"=") { + env[i] = name + "=" + value + return + } + } + env = append(env, name+"="+value) +} + +// setEnvironment assembles the configuration for gotest and its subcommands. +func setEnvironment() { + // Basic environment. + GOROOT = runtime.GOROOT() + addEnv("GOROOT", GOROOT) + GOARCH = runtime.GOARCH + addEnv("GOARCH", GOARCH) + O = theChar[GOARCH] + if O == "" { + Fatalf("unknown architecture %s", GOARCH) + } + + // Commands and their flags. + gc := os.Getenv("GC") + if gc == "" { + gc = O + "g" + } + XGC = []string{gc, "-I", "_test", "-o", "_xtest_." + O} + GC = []string{gc, "-I", "_test", "_testmain.go"} + gl := os.Getenv("GL") + if gl == "" { + gl = O + "l" + } + GL = []string{gl, "-L", "_test", "_testmain." + O} + + // Silence make on Linux + addEnv("MAKEFLAGS", "") + addEnv("MAKELEVEL", "") +} + +// getTestFileNames gets the set of files we're looking at. +// If gotest has no arguments, it scans for file names matching "[^.]*_test.go". 
+func getTestFileNames() { + names := fileNames + if len(names) == 0 { + var err os.Error + names, err = filepath.Glob("[^.]*_test.go") + if err != nil { + Fatalf("Glob pattern error: %s", err) + } + if len(names) == 0 { + Fatalf(`no test files found: no match for "[^.]*_test.go"`) + } + } + for _, n := range names { + fd, err := os.Open(n) + if err != nil { + Fatalf("%s: %s", n, err) + } + f := &File{name: n, file: fd} + files = append(files, f) + } +} + +// parseFiles parses the files and remembers the packages we find. +func parseFiles() { + fileSet := token.NewFileSet() + for _, f := range files { + // Report declaration errors so we can abort if the files are incorrect Go. + file, err := parser.ParseFile(fileSet, f.name, nil, parser.DeclarationErrors) + if err != nil { + Fatalf("parse error: %s", err) + } + f.astFile = file + f.pkg = file.Name.String() + if f.pkg == "" { + Fatalf("cannot happen: no package name in %s", f.name) + } + } +} + +// getTestNames extracts the names of tests and benchmarks. They are all +// top-level functions that are not methods. +func getTestNames() { + for _, f := range files { + for _, d := range f.astFile.Decls { + n, ok := d.(*ast.FuncDecl) + if !ok { + continue + } + if n.Recv != nil { // a method, not a function. + continue + } + name := n.Name.String() + if isTest(name, "Test") { + f.tests = append(f.tests, name) + } else if isTest(name, "Benchmark") { + f.benchmarks = append(f.benchmarks, name) + } + // TODO: worth checking the signature? Probably not. + } + if strings.HasSuffix(f.pkg, "_test") { + outsideFileNames = append(outsideFileNames, f.name) + } else { + insideFileNames = append(insideFileNames, f.name) + } + } +} + +// isTest tells whether name looks like a test (or benchmark, according to prefix). +// It is a Test (say) if there is a character after Test that is not a lower-case letter. +// We don't want TesticularCancer. +func isTest(name, prefix string) bool { + if !strings.HasPrefix(name, prefix) { + return false + } + if len(name) == len(prefix) { // "Test" is ok + return true + } + rune, _ := utf8.DecodeRuneInString(name[len(prefix):]) + return !unicode.IsLower(rune) +} + +func run(args ...string) { + doRun(args, false) +} + +// runWithStdout is like run, but returns the text of standard output with the last newline dropped. +func runWithStdout(argv ...string) string { + s := doRun(argv, true) + if strings.HasSuffix(s, "\r\n") { + s = s[:len(s)-2] + } else if strings.HasSuffix(s, "\n") { + s = s[:len(s)-1] + } + if len(s) == 0 { + Fatalf("no output from command %s", strings.Join(argv, " ")) + } + return s +} + +// runTestWithArgs appends gotest's runs the provided binary with the args passed on the command line. +func runTestWithArgs(binary string) { + doRun(append([]string{binary}, args...), false) +} + +// doRun is the general command runner. The flag says whether we want to +// retrieve standard output. +func doRun(argv []string, returnStdout bool) string { + if xFlag { + fmt.Printf("gotest: %s\n", strings.Join(argv, " ")) + } + command := argv[0] + if runtime.GOOS == "windows" && command == "gomake" { + // gomake is a shell script and it cannot be executed directly on Windows. 
+ cmd := "" + for i, v := range argv { + if i > 0 { + cmd += " " + } + cmd += `"` + v + `"` + } + argv = []string{"cmd", "/c", "sh", "-c", cmd} + } + var err os.Error + argv[0], err = exec.LookPath(argv[0]) + if err != nil { + Fatalf("can't find %s: %s", command, err) + } + procAttr := &os.ProcAttr{ + Env: env, + Files: []*os.File{ + os.Stdin, + os.Stdout, + os.Stderr, + }, + } + var r, w *os.File + if returnStdout { + r, w, err = os.Pipe() + if err != nil { + Fatalf("can't create pipe: %s", err) + } + procAttr.Files[1] = w + } + proc, err := os.StartProcess(argv[0], argv, procAttr) + if err != nil { + Fatalf("%s failed to start: %s", command, err) + } + if returnStdout { + defer r.Close() + w.Close() + } + waitMsg, err := proc.Wait(0) + if err != nil || waitMsg == nil { + Fatalf("%s failed: %s", command, err) + } + if !waitMsg.Exited() || waitMsg.ExitStatus() != 0 { + Fatalf("%q failed: %s", strings.Join(argv, " "), waitMsg) + } + if returnStdout { + b, err := ioutil.ReadAll(r) + if err != nil { + Fatalf("can't read output from command: %s", err) + } + return string(b) + } + return "" +} + +// writeTestmainGo generates the test program to be compiled, "./_testmain.go". +func writeTestmainGo() { + f, err := os.Create("_testmain.go") + if err != nil { + Fatalf("can't create _testmain.go: %s", err) + } + defer f.Close() + b := bufio.NewWriter(f) + defer b.Flush() + + // Package and imports. + fmt.Fprint(b, "package main\n\n") + // Are there tests from a package other than the one we're testing? + // We can't just use file names because some of the things we compiled + // contain no tests. + outsideTests := false + insideTests := false + for _, f := range files { + //println(f.name, f.pkg) + if len(f.tests) == 0 && len(f.benchmarks) == 0 { + continue + } + if strings.HasSuffix(f.pkg, "_test") { + outsideTests = true + } else { + insideTests = true + } + } + if insideTests { + switch importPath { + case "testing": + case "main": + // Import path main is reserved, so import with + // explicit reference to ./_test/main instead. + // Also, the file we are writing defines a function named main, + // so rename this import to __main__ to avoid name conflict. + fmt.Fprintf(b, "import __main__ %q\n", "./_test/main") + default: + fmt.Fprintf(b, "import %q\n", importPath) + } + } + if outsideTests { + fmt.Fprintf(b, "import %q\n", "./_xtest_") + } + fmt.Fprintf(b, "import %q\n", "testing") + fmt.Fprintf(b, "import __os__ %q\n", "os") // rename in case tested package is called os + fmt.Fprintf(b, "import __regexp__ %q\n", "regexp") // rename in case tested package is called regexp + fmt.Fprintln(b) // for gofmt + + // Tests. + fmt.Fprintln(b, "var tests = []testing.InternalTest{") + for _, f := range files { + for _, t := range f.tests { + fmt.Fprintf(b, "\t{\"%s.%s\", %s.%s},\n", f.pkg, t, notMain(f.pkg), t) + } + } + fmt.Fprintln(b, "}") + fmt.Fprintln(b) + + // Benchmarks. + fmt.Fprintln(b, "var benchmarks = []testing.InternalBenchmark{") + for _, f := range files { + for _, bm := range f.benchmarks { + fmt.Fprintf(b, "\t{\"%s.%s\", %s.%s},\n", f.pkg, bm, notMain(f.pkg), bm) + } + } + fmt.Fprintln(b, "}") + + // Body. + fmt.Fprintln(b, testBody) +} + +// notMain returns the package, renaming as appropriate if it's "main". +func notMain(pkg string) string { + if pkg == "main" { + return "__main__" + } + return pkg +} + +// testBody is just copied to the output. It's the code that runs the tests. 
+var testBody = ` +var matchPat string +var matchRe *__regexp__.Regexp + +func matchString(pat, str string) (result bool, err __os__.Error) { + if matchRe == nil || matchPat != pat { + matchPat = pat + matchRe, err = __regexp__.Compile(matchPat) + if err != nil { + return + } + } + return matchRe.MatchString(str), nil +} + +func main() { + testing.Main(matchString, tests, benchmarks) +}` diff --git a/src/cmd/gotry/Makefile b/src/cmd/gotry/Makefile new file mode 100644 index 000000000..6a32bbf2d --- /dev/null +++ b/src/cmd/gotry/Makefile @@ -0,0 +1,18 @@ +# Copyright 2010 The Go Authors. All rights reserved. +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +include ../../Make.inc + +TARG=install + +clean: + @true + +install: install-gotry + +install-%: % + ! test -f "$(GOBIN)"/$* || chmod u+w "$(GOBIN)"/$* + sed 's`@@GOROOT@@`$(GOROOT_FINAL)`' $* >"$(GOBIN)"/$* + chmod +x "$(GOBIN)"/$* + diff --git a/src/cmd/gotest/gotry b/src/cmd/gotry/gotry index 52c5d2d58..52c5d2d58 100755 --- a/src/cmd/gotest/gotry +++ b/src/cmd/gotry/gotry diff --git a/src/cmd/gotype/Makefile b/src/cmd/gotype/Makefile new file mode 100644 index 000000000..18171945d --- /dev/null +++ b/src/cmd/gotype/Makefile @@ -0,0 +1,17 @@ +# Copyright 2011 The Go Authors. All rights reserved. +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +include ../../Make.inc + +TARG=gotype +GOFILES=\ + gotype.go\ + +include ../../Make.cmd + +test: + gotest + +testshort: + gotest -test.short diff --git a/src/cmd/gotype/doc.go b/src/cmd/gotype/doc.go new file mode 100644 index 000000000..ec4eb7c24 --- /dev/null +++ b/src/cmd/gotype/doc.go @@ -0,0 +1,59 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +/* +The gotype command does syntactic and semantic analysis of Go files +and packages similar to the analysis performed by the front-end of +a Go compiler. Errors are reported if the analysis fails; otherwise +gotype is quiet (unless -v is set). + +Without a list of paths, gotype processes the standard input, which must +be the source of a single package file. + +Given a list of file names, each file must be a source file belonging to +the same package unless the package name is explicitly specified with the +-p flag. + +Given a directory name, gotype collects all .go files in the directory +and processes them as if they were provided as an explicit list of file +names. Each directory is processed independently. Files starting with . +or not ending in .go are ignored. + +Usage: + gotype [flags] [path ...] + +The flags are: + -p pkgName + process only those files in package pkgName. + -r + recursively process subdirectories. + -v + verbose mode. + +Debugging flags: + -trace + print parse trace (disables concurrent parsing). + -ast + print AST (disables concurrent parsing). + + +Examples + +To check the files file.go, old.saved, and .ignored: + + gotype file.go old.saved .ignored + +To check all .go files belonging to package main in the current directory +and recursively in all subdirectories: + + gotype -p main -r . + +To verify the output of a pipe: + + echo "package foo" | gotype + +*/ +package documentation + +// BUG(gri): At the moment, only single-file scope analysis is performed. 
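gotype's per-file work, as described in the documentation above, boils down to parsing each source file with declaration-error reporting and optionally filtering by package name. A stripped-down sketch of that single-file step (a hypothetical parsecheck command, not the tool itself; it omits the -p/-r handling and the package construction that the real source below performs):

	package main

	import (
		"fmt"
		"go/parser"
		"go/scanner"
		"go/token"
		"os"
	)

	func main() {
		if len(os.Args) != 2 {
			fmt.Fprintln(os.Stderr, "usage: parsecheck file.go")
			os.Exit(2)
		}
		fset := token.NewFileSet()
		// Parse one file, reporting declaration errors as gotype's parse() does.
		f, err := parser.ParseFile(fset, os.Args[1], nil, parser.DeclarationErrors)
		if err != nil {
			scanner.PrintError(os.Stderr, err) // same error printer gotype's report() uses
			os.Exit(2)
		}
		fmt.Println("package", f.Name.Name)
	}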
diff --git a/src/cmd/gotype/gotype.go b/src/cmd/gotype/gotype.go new file mode 100644 index 000000000..568467322 --- /dev/null +++ b/src/cmd/gotype/gotype.go @@ -0,0 +1,198 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "flag" + "fmt" + "go/ast" + "go/parser" + "go/scanner" + "go/token" + "go/types" + "io/ioutil" + "os" + "path/filepath" + "strings" +) + + +var ( + // main operation modes + pkgName = flag.String("p", "", "process only those files in package pkgName") + recursive = flag.Bool("r", false, "recursively process subdirectories") + verbose = flag.Bool("v", false, "verbose mode") + + // debugging support + printTrace = flag.Bool("trace", false, "print parse trace") + printAST = flag.Bool("ast", false, "print AST") +) + + +var exitCode = 0 + + +func usage() { + fmt.Fprintf(os.Stderr, "usage: gotype [flags] [path ...]\n") + flag.PrintDefaults() + os.Exit(2) +} + + +func report(err os.Error) { + scanner.PrintError(os.Stderr, err) + exitCode = 2 +} + + +// parse returns the AST for the Go source src. +// The filename is for error reporting only. +// The result is nil if there were errors or if +// the file does not belong to the -p package. +func parse(fset *token.FileSet, filename string, src []byte) *ast.File { + if *verbose { + fmt.Println(filename) + } + + // ignore files with different package name + if *pkgName != "" { + file, err := parser.ParseFile(fset, filename, src, parser.PackageClauseOnly) + if err != nil { + report(err) + return nil + } + if file.Name.Name != *pkgName { + if *verbose { + fmt.Printf("\tignored (package %s)\n", file.Name.Name) + } + return nil + } + } + + // parse entire file + mode := parser.DeclarationErrors + if *printTrace { + mode |= parser.Trace + } + file, err := parser.ParseFile(fset, filename, src, mode) + if err != nil { + report(err) + return nil + } + if *printAST { + ast.Print(fset, file) + } + + return file +} + + +func parseStdin(fset *token.FileSet) (files map[string]*ast.File) { + files = make(map[string]*ast.File) + src, err := ioutil.ReadAll(os.Stdin) + if err != nil { + report(err) + return + } + const filename = "<standard input>" + if file := parse(fset, filename, src); file != nil { + files[filename] = file + } + return +} + + +func parseFiles(fset *token.FileSet, filenames []string) (files map[string]*ast.File) { + files = make(map[string]*ast.File) + for _, filename := range filenames { + src, err := ioutil.ReadFile(filename) + if err != nil { + report(err) + continue + } + if file := parse(fset, filename, src); file != nil { + if files[filename] != nil { + report(os.ErrorString(fmt.Sprintf("%q: duplicate file", filename))) + continue + } + files[filename] = file + } + } + return +} + + +func isGoFilename(filename string) bool { + // ignore non-Go files + return !strings.HasPrefix(filename, ".") && strings.HasSuffix(filename, ".go") +} + + +func processDirectory(dirname string) { + f, err := os.Open(dirname) + if err != nil { + report(err) + return + } + filenames, err := f.Readdirnames(-1) + f.Close() + if err != nil { + report(err) + // continue since filenames may not be empty + } + for i, filename := range filenames { + filenames[i] = filepath.Join(dirname, filename) + } + processFiles(filenames, false) +} + + +func processFiles(filenames []string, allFiles bool) { + i := 0 + for _, filename := range filenames { + switch info, err := os.Stat(filename); { + case err != nil: + 
report(err) + case info.IsRegular(): + if allFiles || isGoFilename(info.Name) { + filenames[i] = filename + i++ + } + case info.IsDirectory(): + if allFiles || *recursive { + processDirectory(filename) + } + } + } + fset := token.NewFileSet() + processPackage(fset, parseFiles(fset, filenames[0:i])) +} + + +func processPackage(fset *token.FileSet, files map[string]*ast.File) { + // make a package (resolve all identifiers) + pkg, err := ast.NewPackage(fset, files, types.GcImporter, types.Universe) + if err != nil { + report(err) + return + } + // TODO(gri): typecheck package + _ = pkg +} + + +func main() { + flag.Usage = usage + flag.Parse() + + if flag.NArg() == 0 { + fset := token.NewFileSet() + processPackage(fset, parseStdin(fset)) + } else { + processFiles(flag.Args(), true) + } + + os.Exit(exitCode) +} diff --git a/src/cmd/gotype/gotype_test.go b/src/cmd/gotype/gotype_test.go new file mode 100644 index 000000000..9c8f8f2a7 --- /dev/null +++ b/src/cmd/gotype/gotype_test.go @@ -0,0 +1,52 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "path/filepath" + "runtime" + "testing" +) + + +func runTest(t *testing.T, path, pkg string) { + exitCode = 0 + *pkgName = pkg + *recursive = false + + if pkg == "" { + processFiles([]string{path}, true) + } else { + processDirectory(path) + } + + if exitCode != 0 { + t.Errorf("processing %s failed: exitCode = %d", path, exitCode) + } +} + + +var tests = []struct { + path string + pkg string +}{ + // individual files + {"testdata/test1.go", ""}, + + // directories + {filepath.Join(runtime.GOROOT(), "src/pkg/go/ast"), "ast"}, + {filepath.Join(runtime.GOROOT(), "src/pkg/go/doc"), "doc"}, + {filepath.Join(runtime.GOROOT(), "src/pkg/go/token"), "scanner"}, + {filepath.Join(runtime.GOROOT(), "src/pkg/go/scanner"), "scanner"}, + {filepath.Join(runtime.GOROOT(), "src/pkg/go/parser"), "parser"}, + {filepath.Join(runtime.GOROOT(), "src/pkg/go/types"), "types"}, +} + + +func Test(t *testing.T) { + for _, test := range tests { + runTest(t, test.path, test.pkg) + } +} diff --git a/src/cmd/gotype/testdata/test1.go b/src/cmd/gotype/testdata/test1.go new file mode 100644 index 000000000..0bd46568d --- /dev/null +++ b/src/cmd/gotype/testdata/test1.go @@ -0,0 +1,6 @@ +package p + +func _() { + // the scope of a local type declaration starts immediately after the type name + type T struct{ _ *T } +} diff --git a/src/cmd/govet/govet.go b/src/cmd/govet/govet.go index ff6421de8..b811c61a2 100644 --- a/src/cmd/govet/govet.go +++ b/src/cmd/govet/govet.go @@ -7,7 +7,6 @@ package main import ( - "bytes" "flag" "fmt" "io" @@ -18,6 +17,7 @@ import ( "path/filepath" "strconv" "strings" + "utf8" ) var verbose = flag.Bool("v", false, "verbose") @@ -63,6 +63,7 @@ func main() { } name = name[:colon] } + name = strings.ToLower(name) if name[len(name)-1] == 'f' { printfList[name] = skip } else { @@ -205,35 +206,38 @@ func (f *File) checkCallExpr(call *ast.CallExpr) { } // printfList records the formatted-print functions. The value is the location -// of the format parameter. +// of the format parameter. Names are lower-cased so the lookup is +// case insensitive. var printfList = map[string]int{ - "Errorf": 0, - "Fatalf": 0, - "Fprintf": 1, - "Panicf": 0, - "Printf": 0, - "Sprintf": 0, + "errorf": 0, + "fatalf": 0, + "fprintf": 1, + "panicf": 0, + "printf": 0, + "sprintf": 0, } // printList records the unformatted-print functions. 
The value is the location -// of the first parameter to be printed. +// of the first parameter to be printed. Names are lower-cased so the lookup is +// case insensitive. var printList = map[string]int{ - "Error": 0, - "Fatal": 0, - "Fprint": 1, "Fprintln": 1, - "Panic": 0, "Panicln": 0, - "Print": 0, "Println": 0, - "Sprint": 0, "Sprintln": 0, + "error": 0, + "fatal": 0, + "fprint": 1, "fprintln": 1, + "panic": 0, "panicln": 0, + "print": 0, "println": 0, + "sprint": 0, "sprintln": 0, } // checkCall triggers the print-specific checks if the call invokes a print function. -func (f *File) checkCall(call *ast.CallExpr, name string) { +func (f *File) checkCall(call *ast.CallExpr, Name string) { + name := strings.ToLower(Name) if skip, ok := printfList[name]; ok { - f.checkPrintf(call, name, skip) + f.checkPrintf(call, Name, skip) return } if skip, ok := printList[name]; ok { - f.checkPrint(call, name, skip) + f.checkPrint(call, Name, skip) return } } @@ -256,7 +260,7 @@ func (f *File) checkPrintf(call *ast.CallExpr, name string, skip int) { return } if lit.Kind == token.STRING { - if bytes.IndexByte(lit.Value, '%') < 0 { + if !strings.Contains(lit.Value, "%") { if len(call.Args) > skip+1 { f.Badf(call.Pos(), "no formatting directive in %s call", name) } @@ -265,24 +269,64 @@ func (f *File) checkPrintf(call *ast.CallExpr, name string, skip int) { } // Hard part: check formats against args. // Trivial but useful test: count. - numPercent := 0 - for i := 0; i < len(lit.Value); i++ { + numArgs := 0 + for i, w := 0, 0; i < len(lit.Value); i += w { + w = 1 if lit.Value[i] == '%' { - if i+1 < len(lit.Value) && lit.Value[i+1] == '%' { - // %% doesn't count. - i++ - } else { - numPercent++ - } + nbytes, nargs := parsePrintfVerb(lit.Value[i:]) + w = nbytes + numArgs += nargs } } expect := len(call.Args) - (skip + 1) - if numPercent != expect { - f.Badf(call.Pos(), "wrong number of formatting directives in %s call: %d percent(s) for %d args", name, numPercent, expect) + if numArgs != expect { + f.Badf(call.Pos(), "wrong number of args in %s call: %d needed but %d args", name, numArgs, expect) } } -var terminalNewline = []byte(`\n"`) // \n at end of interpreted string +// parsePrintfVerb returns the number of bytes and number of arguments +// consumed by the Printf directive that begins s, including its percent sign +// and verb. +func parsePrintfVerb(s string) (nbytes, nargs int) { + // There's guaranteed a percent sign. + nbytes = 1 + end := len(s) + // There may be flags. +FlagLoop: + for nbytes < end { + switch s[nbytes] { + case '#', '0', '+', '-', ' ': + nbytes++ + default: + break FlagLoop + } + } + getNum := func() { + if nbytes < end && s[nbytes] == '*' { + nbytes++ + nargs++ + } else { + for nbytes < end && '0' <= s[nbytes] && s[nbytes] <= '9' { + nbytes++ + } + } + } + // There may be a width. + getNum() + // If there's a period, there may be a precision. + if nbytes < end && s[nbytes] == '.' { + nbytes++ + getNum() + } + // Now a verb. + c, w := utf8.DecodeRuneInString(s[nbytes:]) + nbytes += w + if c != '%' { + nargs++ + } + return +} + // checkPrint checks a call to an unformatted print routine such as Println. 
// The skip argument records how many arguments to ignore; that is, @@ -298,7 +342,7 @@ func (f *File) checkPrint(call *ast.CallExpr, name string, skip int) { } arg := args[skip] if lit, ok := arg.(*ast.BasicLit); ok && lit.Kind == token.STRING { - if bytes.IndexByte(lit.Value, '%') >= 0 { + if strings.Contains(lit.Value, "%") { f.Badf(call.Pos(), "possible formatting directive in %s call", name) } } @@ -306,7 +350,7 @@ func (f *File) checkPrint(call *ast.CallExpr, name string, skip int) { // The last item, if a string, should not have a newline. arg = args[len(call.Args)-1] if lit, ok := arg.(*ast.BasicLit); ok && lit.Kind == token.STRING { - if bytes.HasSuffix(lit.Value, terminalNewline) { + if strings.HasSuffix(lit.Value, `\n"`) { f.Badf(call.Pos(), "%s call ends with newline", name) } } @@ -320,8 +364,16 @@ func BadFunctionUsedInTests() { fmt.Println("%s", "hi") // % in call to Println fmt.Printf("%s", "hi", 3) // wrong # percents fmt.Printf("%s%%%d", "hi", 3) // right # percents + fmt.Printf("%.*d", 3, 3) // right # percents, with a * + fmt.Printf("%.*d", 3, 3, 3) // wrong # percents, with a * + printf("now is the time", "buddy") // no %s Printf("now is the time", "buddy") // no %s f := new(File) f.Warn(0, "%s", "hello", 3) // % in call to added function f.Warnf(0, "%s", "hello", 3) // wrong # %s in call to added function } + +// printf is used by the test. +func printf(format string, args ...interface{}) { + panic("don't call - testing only") +} diff --git a/src/cmd/goyacc/goyacc.go b/src/cmd/goyacc/goyacc.go index 32816b700..220c99492 100644 --- a/src/cmd/goyacc/goyacc.go +++ b/src/cmd/goyacc/goyacc.go @@ -382,7 +382,7 @@ outer: for { switch t { default: - error("syntax error tok=%v", t-PRIVATE) + errorf("syntax error tok=%v", t-PRIVATE) case MARK, ENDFILE: break outer @@ -392,14 +392,14 @@ outer: case START: t = gettok() if t != IDENTIFIER { - error("bad %%start construction") + errorf("bad %%start construction") } start = chfind(1, tokname) case TYPEDEF: t = gettok() if t != TYPENAME { - error("bad syntax in %%type") + errorf("bad syntax in %%type") } ty = numbval for { @@ -410,7 +410,7 @@ outer: if t < NTBASE { j = TYPE(toklev[t]) if j != 0 && j != ty { - error("type redeclaration of token ", + errorf("type redeclaration of token ", tokset[t].name) } else { toklev[t] = SETTYPE(toklev[t], ty) @@ -418,7 +418,7 @@ outer: } else { j = nontrst[t-NTBASE].value if j != 0 && j != ty { - error("type redeclaration of nonterminal %v", + errorf("type redeclaration of nonterminal %v", nontrst[t-NTBASE].name) } else { nontrst[t-NTBASE].value = ty @@ -464,18 +464,18 @@ outer: case IDENTIFIER: j = chfind(0, tokname) if j >= NTBASE { - error("%v defined earlier as nonterminal", tokname) + errorf("%v defined earlier as nonterminal", tokname) } if lev != 0 { if ASSOC(toklev[j]) != 0 { - error("redeclaration of precedence of %v", tokname) + errorf("redeclaration of precedence of %v", tokname) } toklev[j] = SETASC(toklev[j], lev) toklev[j] = SETPLEV(toklev[j], i) } if ty != 0 { if TYPE(toklev[j]) != 0 { - error("redeclaration of type of %v", tokname) + errorf("redeclaration of type of %v", tokname) } toklev[j] = SETTYPE(toklev[j], ty) } @@ -498,7 +498,7 @@ outer: } if t == ENDFILE { - error("unexpected EOF before %%") + errorf("unexpected EOF before %%") } // put out non-literal terminals @@ -533,7 +533,7 @@ outer: curprod := make([]int, RULEINC) t = gettok() if t != IDENTCOLON { - error("bad syntax on first rule") + errorf("bad syntax on first rule") } if start == 0 { @@ -557,11 +557,11 @@ outer: } 
else if t == IDENTCOLON { curprod[mem] = chfind(1, tokname) if curprod[mem] < NTBASE { - error("token illegal on LHS of grammar rule") + errorf("token illegal on LHS of grammar rule") } mem++ } else { - error("illegal rule: missing semicolon or | ?") + errorf("illegal rule: missing semicolon or | ?") } // read rule body @@ -582,11 +582,11 @@ outer: } if t == PREC { if gettok() != IDENTIFIER { - error("illegal %%prec syntax") + errorf("illegal %%prec syntax") } j = chfind(2, tokname) if j >= NTBASE { - error("nonterminal " + nontrst[j-NTBASE].name + " illegal after %%prec") + errorf("nonterminal " + nontrst[j-NTBASE].name + " illegal after %%prec") } levprd[nprod] = toklev[j] t = gettok() @@ -642,7 +642,7 @@ outer: // no explicit action, LHS has value tempty := curprod[1] if tempty < 0 { - error("must return a value, since LHS has a type") + errorf("must return a value, since LHS has a type") } if tempty >= NTBASE { tempty = nontrst[tempty-NTBASE].value @@ -650,7 +650,7 @@ outer: tempty = TYPE(toklev[tempty]) } if tempty != nontrst[curprod[0]-NTBASE].value { - error("default action causes potential type clash") + errorf("default action causes potential type clash") } fmt.Fprintf(fcode, "\ncase %v:", nprod) fmt.Fprintf(fcode, "\n\t%sVAL.%v = %sS[%spt-0].%v;", @@ -773,7 +773,7 @@ func defin(nt int, s string) int { case 'v': val = '\v' default: - error("invalid escape %v", s[1:3]) + errorf("invalid escape %v", s[1:3]) } } else if s[2] == 'u' && len(s) == 2+1+4 { // \unnnn sequence val = 0 @@ -788,16 +788,16 @@ func defin(nt int, s string) int { case c >= 'A' && c <= 'F': c -= 'A' - 10 default: - error("illegal \\unnnn construction") + errorf("illegal \\unnnn construction") } val = val*16 + c s = s[1:] } if val == 0 { - error("'\\u0000' is illegal") + errorf("'\\u0000' is illegal") } } else { - error("unknown escape") + errorf("unknown escape") } } else { val = extval @@ -855,7 +855,7 @@ func gettok() int { } if c != '>' { - error("unterminated < ... > clause") + errorf("unterminated < ... 
> clause") } for i = 1; i <= ntypes; i++ { @@ -881,7 +881,7 @@ func gettok() int { for { c = getrune(finput) if c == '\n' || c == EOF { - error("illegal or missing ' or \"") + errorf("illegal or missing ' or \"") } if c == '\\' { tokname += string('\\') @@ -926,7 +926,7 @@ func gettok() int { return resrv[c].value } } - error("invalid escape, or illegal reserved word: %v", tokname) + errorf("invalid escape, or illegal reserved word: %v", tokname) case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9': numbval = c - '0' @@ -1004,7 +1004,7 @@ func fdtype(t int) int { s = tokset[t].name } if v <= 0 { - error("must specify type for %v", s) + errorf("must specify type for %v", s) } return v } @@ -1026,7 +1026,7 @@ func chfind(t int, s string) int { // cannot find name if t > 1 { - error("%v should have been defined earlier", s) + errorf("%v should have been defined earlier", s) } return defin(t, s) } @@ -1047,7 +1047,7 @@ out: for { c := getrune(finput) if c == EOF { - error("EOF encountered while processing %%union") + errorf("EOF encountered while processing %%union") } ftable.WriteRune(c) switch c { @@ -1097,7 +1097,7 @@ func cpycode() { c = getrune(finput) } lineno = lno - error("eof before %%}") + errorf("eof before %%}") } // @@ -1115,11 +1115,11 @@ func skipcom() int { } c = getrune(finput) } - error("EOF inside comment") + errorf("EOF inside comment") return 0 } if c != '*' { - error("illegal comment") + errorf("illegal comment") } nl := 0 // lines skipped @@ -1196,7 +1196,7 @@ loop: if c == '<' { ungetrune(finput, c) if gettok() != TYPENAME { - error("bad syntax on $<ident> clause") + errorf("bad syntax on $<ident> clause") } tok = numbval c = getrune(finput) @@ -1226,13 +1226,13 @@ loop: ungetrune(finput, c) j = j * s if j >= max { - error("Illegal use of $%v", j) + errorf("Illegal use of $%v", j) } } else if isword(c) || c == '_' || c == '.' 
{ // look for $name ungetrune(finput, c) if gettok() != IDENTIFIER { - error("$ must be followed by an identifier") + errorf("$ must be followed by an identifier") } tokn := chfind(2, tokname) fnd := -1 @@ -1240,7 +1240,7 @@ loop: if c != '@' { ungetrune(finput, c) } else if gettok() != NUMBER { - error("@ must be followed by number") + errorf("@ must be followed by number") } else { fnd = numbval } @@ -1253,7 +1253,7 @@ loop: } } if j >= max { - error("$name or $name@number not found") + errorf("$name or $name@number not found") } } else { fcode.WriteRune('$') @@ -1268,7 +1268,7 @@ loop: // put out the proper tag if ntypes != 0 { if j <= 0 && tok < 0 { - error("must specify type of $%v", j) + errorf("must specify type of $%v", j) } if tok < 0 { tok = fdtype(curprod[j]) @@ -1315,7 +1315,7 @@ loop: fcode.WriteRune(c) c = getrune(finput) } - error("EOF inside comment") + errorf("EOF inside comment") case '\'', '"': // character string or constant @@ -1333,16 +1333,16 @@ loop: break swt } if c == '\n' { - error("newline in string or char const") + errorf("newline in string or char const") } fcode.WriteRune(c) c = getrune(finput) } - error("EOF in string or character constant") + errorf("EOF in string or character constant") case EOF: lineno = lno - error("action does not terminate") + errorf("action does not terminate") case '\n': lineno++ @@ -1356,14 +1356,14 @@ func openup() { infile = flag.Arg(0) finput = open(infile) if finput == nil { - error("cannot open %v", infile) + errorf("cannot open %v", infile) } foutput = nil if vflag != "" { - foutput = create(vflag, 0666) + foutput = create(vflag) if foutput == nil { - error("can't create file %v", vflag) + errorf("can't create file %v", vflag) } } @@ -1371,9 +1371,9 @@ func openup() { if oflag == "" { oflag = "y.go" } - ftable = create(oflag, 0666) + ftable = create(oflag) if ftable == nil { - error("can't create file %v", oflag) + errorf("can't create file %v", oflag) } } @@ -1433,7 +1433,7 @@ func cpres() { } } if n == 0 { - error("nonterminal %v not defined", nontrst[i].name) + errorf("nonterminal %v not defined", nontrst[i].name) continue } pres[i] = make([][]int, n) @@ -1506,7 +1506,7 @@ more: } if pempty[i] != OK { fatfl = 0 - error("nonterminal " + nontrst[i].name + " never derives any token string") + errorf("nonterminal " + nontrst[i].name + " never derives any token string") } } @@ -1921,11 +1921,11 @@ look: // state is new zznewstate++ if nolook != 0 { - error("yacc state/nolook error") + errorf("yacc state/nolook error") } pstate[nstate+2] = p2 if nstate+1 >= NSTATES { - error("too many states") + errorf("too many states") } if c >= NTBASE { mstates[nstate] = ntstates[c-NTBASE] @@ -2061,7 +2061,7 @@ nextk: } return off + rr } - error("no space in action table") + errorf("no space in action table") return 0 } @@ -2623,7 +2623,7 @@ nextgp: if s > maxa { maxa = s if maxa >= ACTSIZE { - error("a array overflow") + errorf("a array overflow") } } if amem[s] != 0 { @@ -2646,7 +2646,7 @@ nextgp: } return } - error("cannot place goto %v\n", i) + errorf("cannot place goto %v\n", i) } func stin(i int) { @@ -2705,7 +2705,7 @@ nextn: maxa = s } if amem[s] != 0 && amem[s] != q[r+1] { - error("clobber of a array, pos'n %v, by %v", s, q[r+1]) + errorf("clobber of a array, pos'n %v, by %v", s, q[r+1]) } amem[s] = q[r+1] } @@ -2715,7 +2715,7 @@ nextn: } return } - error("Error; failure to place state %v", i) + errorf("Error; failure to place state %v", i) } // @@ -3014,7 +3014,7 @@ func getrune(f *bufio.Reader) int { return EOF } if err != nil { - 
error("read error: %v", err) + errorf("read error: %v", err) } //fmt.Printf("rune = %v n=%v\n", string(c), n); return c @@ -3036,27 +3036,27 @@ func write(f *bufio.Writer, b []byte, n int) int { } func open(s string) *bufio.Reader { - fi, err := os.Open(s, os.O_RDONLY, 0) + fi, err := os.Open(s) if err != nil { - error("error opening %v: %v", s, err) + errorf("error opening %v: %v", s, err) } //fmt.Printf("open %v\n", s); return bufio.NewReader(fi) } -func create(s string, m uint32) *bufio.Writer { - fo, err := os.Open(s, os.O_WRONLY|os.O_CREAT|os.O_TRUNC, m) +func create(s string) *bufio.Writer { + fo, err := os.Create(s) if err != nil { - error("error opening %v: %v", s, err) + errorf("error creating %v: %v", s, err) } - //fmt.Printf("create %v mode %v\n", s, m); + //fmt.Printf("create %v mode %v\n", s); return bufio.NewWriter(fo) } // // write out error comment // -func error(s string, v ...interface{}) { +func errorf(s string, v ...interface{}) { nerrors++ fmt.Fprintf(stderr, s, v...) fmt.Fprintf(stderr, ": %v:%v\n", infile, lineno) diff --git a/src/cmd/goyacc/units.y b/src/cmd/goyacc/units.y index 5d3f9aca2..06ce11693 100644 --- a/src/cmd/goyacc/units.y +++ b/src/cmd/goyacc/units.y @@ -90,7 +90,7 @@ prog: $2.node = $3; $2.node.dim[0] = 1; if f != 0 { - Error("redefinition of %v", $2.name); + Errorf("redefinition of %v", $2.name); } else if vflag { fmt.Printf("%v\t%v\n", $2.name, &$2.node); @@ -106,7 +106,7 @@ prog: } } if i >= Ndim { - Error("too many dimensions"); + Errorf("too many dimensions"); i = Ndim-1; } fund[i] = $2; @@ -116,7 +116,7 @@ prog: $2.node.dim[0] = 1; $2.node.dim[i] = 1; if f != 0 { - Error("redefinition of %v", $2.name); + Errorf("redefinition of %v", $2.name); } else if vflag { fmt.Printf("%v\t#\n", $2.name); @@ -175,7 +175,7 @@ expr2: for i=1; i<Ndim; i++ { if $3.dim[i] != 0 { - Error("exponent has units"); + Errorf("exponent has units"); $$ = $1; break; } @@ -183,7 +183,7 @@ expr2: if i >= Ndim { i = int($3.vval); if float64(i) != $3.vval { - Error("exponent not integral"); + Errorf("exponent not integral"); } xpn(&$$, &$1, i); } @@ -200,7 +200,7 @@ expr0: VAR { if $1.node.dim[0] == 0 { - Error("undefined %v", $1.name); + Errorf("undefined %v", $1.name); $$ = one; } else $$ = $1.node; @@ -284,7 +284,7 @@ numb: } func (UnitsLex) Error(s string) { - Error("syntax error, last name: %v", sym) + Errorf("syntax error, last name: %v", sym) } func main() { @@ -299,7 +299,7 @@ func main() { file = flag.Arg(0) } - f, err := os.Open(file, os.O_RDONLY, 0) + f, err := os.Open(file) if err != nil { fmt.Fprintf(os.Stderr, "error opening %v: %v\n", file, err) os.Exit(1) @@ -391,7 +391,7 @@ func rdigit(c int) bool { return false } -func Error(s string, v ...interface{}) { +func Errorf(s string, v ...interface{}) { fmt.Printf("%v: %v\n\t", lineno, line) fmt.Printf(s, v...) 
fmt.Printf("\n") @@ -411,7 +411,7 @@ func add(c, a, b *Node) { d = a.dim[i] c.dim[i] = d if d != b.dim[i] { - Error("add must be like units") + Errorf("add must be like units") } } c.vval = fadd(a.vval, b.vval) @@ -425,7 +425,7 @@ func sub(c, a, b *Node) { d = a.dim[i] c.dim[i] = d if d != b.dim[i] { - Error("sub must be like units") + Errorf("sub must be like units") } } c.vval = fadd(a.vval, -b.vval) @@ -711,11 +711,11 @@ func fmul(a, b float64) float64 { } if l > Maxe { - Error("overflow in multiply") + Errorf("overflow in multiply") return 1 } if l < -Maxe { - Error("underflow in multiply") + Errorf("underflow in multiply") return 0 } return a * b @@ -728,7 +728,7 @@ func fdiv(a, b float64) float64 { if b <= 0 { if b == 0 { - Error("division by zero: %v %v", a, b) + Errorf("division by zero: %v %v", a, b) return 1 } l = math.Log(-b) @@ -746,11 +746,11 @@ func fdiv(a, b float64) float64 { } if l < -Maxe { - Error("overflow in divide") + Errorf("overflow in divide") return 1 } if l > Maxe { - Error("underflow in divide") + Errorf("underflow in divide") return 0 } return a / b diff --git a/src/cmd/ld/data.c b/src/cmd/ld/data.c index a20b057ce..d27416dac 100644 --- a/src/cmd/ld/data.c +++ b/src/cmd/ld/data.c @@ -722,7 +722,7 @@ addsize(Sym *s, Sym *t) void dodata(void) { - int32 h, t, datsize; + int32 t, datsize; Section *sect; Sym *s, *last, **l; @@ -732,23 +732,22 @@ dodata(void) last = nil; datap = nil; - for(h=0; h<NHASH; h++) { - for(s=hash[h]; s!=S; s=s->hash){ - if(!s->reachable || s->special) - continue; - if(STEXT < s->type && s->type < SXREF) { - if(last == nil) - datap = s; - else - last->next = s; - s->next = nil; - last = s; - } + + for(s=allsym; s!=S; s=s->allsym) { + if(!s->reachable || s->special) + continue; + if(STEXT < s->type && s->type < SXREF) { + if(last == nil) + datap = s; + else + last->next = s; + s->next = nil; + last = s; } } for(s = datap; s != nil; s = s->next) { - if(s->np > 0 && s->type == SBSS) // TODO: necessary? + if(s->np > 0 && s->type == SBSS) s->type = SDATA; if(s->np > s->size) diag("%s: initialize bounds (%lld < %d)", @@ -786,8 +785,7 @@ dodata(void) s = datap; for(; s != nil && s->type < SDATA; s = s->next) { s->type = SRODATA; - t = rnd(s->size, 4); - s->size = t; + t = rnd(s->size, PtrSize); s->value = datsize; datsize += t; } @@ -834,7 +832,6 @@ dodata(void) datsize = rnd(datsize, 4); else datsize = rnd(datsize, 8); - s->size = t; s->value = datsize; datsize += t; } diff --git a/src/cmd/ld/dwarf.c b/src/cmd/ld/dwarf.c index 5ba4b7c64..fa55fcbb4 100644 --- a/src/cmd/ld/dwarf.c +++ b/src/cmd/ld/dwarf.c @@ -6,7 +6,6 @@ // - eliminate DW_CLS_ if not used // - package info in compilation units // - assign global variables and types to their packages -// - (upstream) type info for C parts of runtime // - gdb uses c syntax, meaning clumsy quoting is needed for go identifiers. eg // ptype struct '[]uint8' and qualifiers need to be quoted away // - lexical scoping is lost, so gdb gets confused as to which 'main.i' you mean. @@ -943,14 +942,16 @@ enum { static char* decodetype_structfieldname(Sym *s, int i) { + Reloc *r; + // go.string."foo" 0x28 / 0x40 s = decode_reloc_sym(s, CommonSize + PtrSize + 2*4 + i*StructFieldSize); if (s == nil) // embedded structs have a nil name. return nil; - s = decode_reloc_sym(s, 0); // string."foo" - if (s == nil) // shouldn't happen. + r = decode_reloc(s, 0); // s has a pointer to the string data at offset 0 + if (r == nil) // shouldn't happen. 
return nil; - return (char*)s->p; // the c-string + return (char*) r->sym->p + r->add; // the c-string } static Sym* @@ -989,8 +990,8 @@ lookup_or_diag(char *n) { Sym *s; - s = lookup(n, 0); - if (s->size == 0) { + s = rlookup(n, 0); + if (s == nil || s->size == 0) { diag("dwarf: missing type: %s", n); errorexit(); } @@ -1021,22 +1022,8 @@ defgotype(Sym *gotype) if (die != nil) return die; - if (0 && debug['v'] > 2) { - print("new type: %s @0x%08x [%d]", gotype->name, gotype->value, gotype->size); - for (i = 0; i < gotype->size; i++) { - if (!(i%8)) print("\n\t%04x ", i); - print("%02x ", gotype->p[i]); - } - print("\n"); - for (i = 0; i < gotype->nr; i++) { - print("\t0x%02x[%x] %d %s[%llx]\n", - gotype->r[i].off, - gotype->r[i].siz, - gotype->r[i].type, - gotype->r[i].sym->name, - (vlong)gotype->r[i].add); - } - } + if (0 && debug['v'] > 2) + print("new type: %Y\n", gotype); kind = decodetype_kind(gotype); bytesize = decodetype_size(gotype); @@ -1117,7 +1104,6 @@ defgotype(Sym *gotype) fld = newdie(die, DW_ABRV_FUNCTYPEPARAM, s->name+5); newrefattr(fld, DW_AT_type, defptrto(defgotype(s))); } - die = defptrto(die); break; case KindInterface: @@ -1391,7 +1377,7 @@ static void synthesizechantypes(DWDie *die) { DWDie *sudog, *waitq, *link, *hchan, - *dws, *dww, *dwl, *dwh, *elemtype; + *dws, *dww, *dwh, *elemtype; DWAttr *a; int elemsize, linksize, sudogsize; @@ -1430,21 +1416,10 @@ synthesizechantypes(DWDie *die) newattr(dww, DW_AT_byte_size, DW_CLS_CONSTANT, getattr(waitq, DW_AT_byte_size)->value, NULL); - // link<T> - dwl = newdie(&dwtypes, DW_ABRV_STRUCTTYPE, - mkinternaltypename("link", getattr(elemtype, DW_AT_name)->data, NULL)); - copychildren(dwl, link); - substitutetype(dwl, "link", defptrto(dwl)); - substitutetype(dwl, "elem", elemtype); - newattr(dwl, DW_AT_byte_size, DW_CLS_CONSTANT, - linksize + (elemsize > 8 ? 
elemsize - 8 : 0), NULL); - // hchan<T> dwh = newdie(&dwtypes, DW_ABRV_STRUCTTYPE, mkinternaltypename("hchan", getattr(elemtype, DW_AT_name)->data, NULL)); copychildren(dwh, hchan); - substitutetype(dwh, "senddataq", defptrto(dwl)); - substitutetype(dwh, "recvdataq", defptrto(dwl)); substitutetype(dwh, "recvq", dww); substitutetype(dwh, "sendq", dww); substitutetype(dwh, "free", dws); @@ -1463,12 +1438,8 @@ defdwsymb(Sym* sym, char *s, int t, vlong v, vlong size, int ver, Sym *gotype) if (strncmp(s, "go.string.", 10) == 0) return; - if (strncmp(s, "string.", 7) == 0) - return; - if (strncmp(s, "type._.", 7) == 0) - return; - if (strncmp(s, "type.", 5) == 0) { + if (strncmp(s, "type.", 5) == 0 && strcmp(s, "type.*") != 0) { defgotype(sym); return; } diff --git a/src/cmd/ld/elf.c b/src/cmd/ld/elf.c index d5b0b0311..b0cce4985 100644 --- a/src/cmd/ld/elf.c +++ b/src/cmd/ld/elf.c @@ -336,7 +336,7 @@ void elfdynhash(void) { Sym *s, *sy; - int i, h, nbucket, b; + int i, nbucket, b; uchar *pc; uint32 hc, g; uint32 *chain, *buckets; @@ -367,26 +367,24 @@ elfdynhash(void) } memset(chain, 0, nsym * sizeof(uint32)); memset(buckets, 0, nbucket * sizeof(uint32)); - for(h = 0; h<NHASH; h++) { - for(sy=hash[h]; sy!=S; sy=sy->hash) { - if (sy->dynid <= 0) - continue; - - hc = 0; - name = sy->dynimpname; - if(name == nil) - name = sy->name; - for(pc = (uchar*)name; *pc; pc++) { - hc = (hc<<4) + *pc; - g = hc & 0xf0000000; - hc ^= g >> 24; - hc &= ~g; - } - - b = hc % nbucket; - chain[sy->dynid] = buckets[b]; - buckets[b] = sy->dynid; + for(sy=allsym; sy!=S; sy=sy->allsym) { + if (sy->dynid <= 0) + continue; + + hc = 0; + name = sy->dynimpname; + if(name == nil) + name = sy->name; + for(pc = (uchar*)name; *pc; pc++) { + hc = (hc<<4) + *pc; + g = hc & 0xf0000000; + hc ^= g >> 24; + hc &= ~g; } + + b = hc % nbucket; + chain[sy->dynid] = buckets[b]; + buckets[b] = sy->dynid; } adduint32(s, nbucket); diff --git a/src/cmd/ld/elf.h b/src/cmd/ld/elf.h index df15cb115..b27ae679b 100644 --- a/src/cmd/ld/elf.h +++ b/src/cmd/ld/elf.h @@ -946,10 +946,10 @@ typedef Elf64_Shdr ElfShdr; typedef Elf64_Phdr ElfPhdr; void elfinit(void); -ElfEhdr *getElfEhdr(); +ElfEhdr *getElfEhdr(void); ElfShdr *newElfShstrtab(vlong); ElfShdr *newElfShdr(vlong); -ElfPhdr *newElfPhdr(); +ElfPhdr *newElfPhdr(void); uint32 elfwritehdr(void); uint32 elfwritephdrs(void); uint32 elfwriteshdrs(void); diff --git a/src/cmd/ld/go.c b/src/cmd/ld/go.c index 3c1e230b4..055163d08 100644 --- a/src/cmd/ld/go.c +++ b/src/cmd/ld/go.c @@ -135,6 +135,7 @@ ldpkg(Biobuf *f, char *pkg, int64 len, char *filename, int whence) if(debug['u'] && whence != ArchiveObj && (p0+6 > p1 || memcmp(p0, " safe\n", 6) != 0)) { fprint(2, "%s: load of unsafe package %s\n", argv0, filename); + nerrors++; errorexit(); } if(p0 < p1) { @@ -657,27 +658,25 @@ deadcode(void) else last->next = nil; - for(i=0; i<NHASH; i++) - for(s = hash[i]; s != S; s = s->hash) + for(s = allsym; s != S; s = s->allsym) if(strncmp(s->name, "weak.", 5) == 0) { s->special = 1; // do not lay out in data segment s->reachable = 1; + s->hide = 1; } } void doweak(void) { - int i; Sym *s, *t; // resolve weak references only if // target symbol will be in binary anyway. 
- for(i=0; i<NHASH; i++) - for(s = hash[i]; s != S; s = s->hash) { + for(s = allsym; s != S; s = s->allsym) { if(strncmp(s->name, "weak.", 5) == 0) { - t = lookup(s->name+5, s->version); - if(t->type != 0 && t->reachable) { + t = rlookup(s->name+5, s->version); + if(t && t->type != 0 && t->reachable) { s->value = t->value; s->type = t->type; } else { diff --git a/src/cmd/ld/ldmacho.c b/src/cmd/ld/ldmacho.c index 7e38db0e4..bbb21d51a 100644 --- a/src/cmd/ld/ldmacho.c +++ b/src/cmd/ld/ldmacho.c @@ -422,6 +422,7 @@ void ldmacho(Biobuf *f, char *pkg, int64 len, char *pn) { int i, j, is64; + uint64 secaddr; uchar hdr[7*4], *cmdp; uchar tmp[4]; uchar *dat; @@ -581,7 +582,11 @@ ldmacho(Biobuf *f, char *pkg, int64 len, char *pn) else s->type = SRODATA; } else { - s->type = SDATA; + if (strcmp(sect->name, "__bss") == 0) { + s->type = SBSS; + s->np = 0; + } else + s->type = SDATA; } if(s->type == STEXT) { if(etextp) @@ -751,8 +756,29 @@ ldmacho(Biobuf *f, char *pkg, int64 len, char *pn) rp->siz = rel->length; rp->type = 512 + (rel->type<<1) + rel->pcrel; rp->off = rel->addr; - - rp->add = e->e32(s->p+rp->off); + + // Handle X86_64_RELOC_SIGNED referencing a section (rel->extrn == 0). + if (thechar == '6' && rel->extrn == 0 && rel->type == 1) { + // Calculate the addend as the offset into the section. + // + // The rip-relative offset stored in the object file is encoded + // as follows: + // + // movsd 0x00000360(%rip),%xmm0 + // + // To get the absolute address of the value this rip-relative address is pointing + // to, we must add the address of the next instruction to it. This is done by + // taking the address of the relocation and adding 4 to it (since the rip-relative + // offset can at most be 32 bits long). To calculate the offset into the section the + // relocation is referencing, we subtract the vaddr of the start of the referenced + // section found in the original object file. + // + // [For future reference, see Darwin's /usr/include/mach-o/x86_64/reloc.h] + secaddr = c->seg.sect[rel->symnum-1].addr; + rp->add = e->e32(s->p+rp->off) + rp->off + 4 - secaddr; + } else + rp->add = e->e32(s->p+rp->off); + // For i386 Mach-O PC-relative, the addend is written such that // it *is* the PC being subtracted. Use that to make // it match our version of PC-relative. diff --git a/src/cmd/ld/lib.c b/src/cmd/ld/lib.c index e645502b3..8cd570463 100644 --- a/src/cmd/ld/lib.c +++ b/src/cmd/ld/lib.c @@ -61,6 +61,7 @@ void libinit(void) { fmtinstall('i', iconv); + fmtinstall('Y', Yconv); mywhatsys(); // get goroot, goarch, goos if(strcmp(goarch, thestring) != 0) print("goarch is not known: %s\n", goarch); @@ -470,8 +471,8 @@ eof: diag("truncated object file: %s", pn); } -Sym* -lookup(char *symb, int v) +static Sym* +_lookup(char *symb, int v, int creat) { Sym *s; char *p; @@ -485,10 +486,12 @@ lookup(char *symb, int v) // not if(h < 0) h = ~h, because gcc 4.3 -O2 miscompiles it. 
h &= 0xffffff; h %= NHASH; + c = symb[0]; for(s = hash[h]; s != S; s = s->hash) - if(s->version == v) if(memcmp(s->name, symb, l) == 0) return s; + if(!creat) + return nil; s = mal(sizeof(*s)); if(debug['v'] > 1) @@ -508,9 +511,25 @@ lookup(char *symb, int v) s->size = 0; hash[h] = s; nsymbol++; + + s->allsym = allsym; + allsym = s; return s; } +Sym* +lookup(char *name, int v) +{ + return _lookup(name, v, 1); +} + +// read-only lookup +Sym* +rlookup(char *name, int v) +{ + return _lookup(name, v, 0); +} + void copyhistfrog(char *buf, int nbuf) { @@ -829,7 +848,7 @@ unmal(void *v, uint32 n) // Copied from ../gc/subr.c:/^pathtoprefix; must stay in sync. /* * Convert raw string to the prefix that will be used in the symbol table. - * Invalid bytes turn into %xx. Right now the only bytes that need + * Invalid bytes turn into %xx. Right now the only bytes that need * escaping are %, ., and ", but we escape all control characters too. */ static char* @@ -1104,7 +1123,7 @@ static Sym *newstack; enum { HasLinkRegister = (thechar == '5'), - CallSize = (!HasLinkRegister)*PtrSize, // bytes of stack required for a call + CallSize = (!HasLinkRegister)*PtrSize, // bytes of stack required for a call }; void @@ -1130,7 +1149,7 @@ dostkcheck(void) // Check calling contexts. // Some nosplits get called a little further down, - // like newproc and deferproc. We could hard-code + // like newproc and deferproc. We could hard-code // that knowledge but it's more robust to look at // the actual call sites. for(s = textp; s != nil; s = s->next) { @@ -1283,11 +1302,44 @@ headtype(char *name) void undef(void) { - int i; Sym *s; - for(i=0; i<NHASH; i++) - for(s = hash[i]; s != S; s = s->hash) + for(s = allsym; s != S; s = s->allsym) if(s->type == SXREF) diag("%s(%d): not defined", s->name, s->version); } + +int +Yconv(Fmt *fp) +{ + Sym *s; + Fmt fmt; + int i; + char *str; + + s = va_arg(fp->args, Sym*); + if (s == S) { + fmtprint(fp, "<nil>"); + } else { + fmtstrinit(&fmt); + fmtprint(&fmt, "%s @0x%08x [%d]", s->name, s->value, s->size); + for (i = 0; i < s->size; i++) { + if (!(i%8)) fmtprint(&fmt, "\n\t0x%04x ", i); + fmtprint(&fmt, "%02x ", s->p[i]); + } + fmtprint(&fmt, "\n"); + for (i = 0; i < s->nr; i++) { + fmtprint(&fmt, "\t0x%04x[%x] %d %s[%llx]\n", + s->r[i].off, + s->r[i].siz, + s->r[i].type, + s->r[i].sym->name, + (vlong)s->r[i].add); + } + str = fmtstrflush(&fmt); + fmtstrcpy(fp, str); + free(str); + } + + return 0; +} diff --git a/src/cmd/ld/lib.h b/src/cmd/ld/lib.h index adde2c9ff..646aeb535 100644 --- a/src/cmd/ld/lib.h +++ b/src/cmd/ld/lib.h @@ -28,8 +28,37 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -// Where symbol table data gets mapped into memory. 
-#define SYMDATVA 0x99LL<<24 +enum +{ + Sxxx, + + /* order here is order in output file */ + STEXT, + SELFDATA, + SMACHOPLT, + STYPE, + SSTRING, + SGOSTRING, + SRODATA, + SDATA, + SMACHO, /* Mach-O __nl_symbol_ptr */ + SMACHOGOT, + SWINDOWS, + SBSS, + + SXREF, + SMACHODYNSTR, + SMACHODYNSYM, + SMACHOINDIRECTPLT, + SMACHOINDIRECTGOT, + SFILE, + SCONST, + SDYNIMPORT, + + SSUB = 1<<8, /* sub-symbol, linked from parent via ->sub list */ + + NHASH = 100003, +}; typedef struct Library Library; struct Library @@ -79,6 +108,7 @@ EXTERN Library* library; EXTERN int libraryp; EXTERN int nlibrary; EXTERN Sym* hash[NHASH]; +EXTERN Sym* allsym; EXTERN Sym* histfrog[MAXHIST]; EXTERN uchar fnuxi8[8]; EXTERN uchar fnuxi4[4]; @@ -106,6 +136,7 @@ void asmlc(void); void histtoauto(void); void collapsefrog(Sym *s); Sym* lookup(char *symb, int v); +Sym* rlookup(char *symb, int v); void nuxiinit(void); int find1(int32 l, int c); int find2(int32 l, int c); @@ -243,3 +274,6 @@ EXTERN char* headstring; extern Header headers[]; int headtype(char*); + +int Yconv(Fmt*); +#pragma varargck type "Y" Sym* diff --git a/src/cmd/ld/pe.c b/src/cmd/ld/pe.c index e72b0b2a0..0d4240e36 100644 --- a/src/cmd/ld/pe.c +++ b/src/cmd/ld/pe.c @@ -171,12 +171,10 @@ initdynimport(void) Imp *m; Dll *d; Sym *s, *dynamic; - int i; dr = nil; - for(i=0; i<NHASH; i++) - for(s = hash[i]; s != S; s = s->hash) { + for(s = allsym; s != S; s = s->allsym) { if(!s->reachable || !s->dynimpname || s->dynexport) continue; for(d = dr; d != nil; d = d->next) { @@ -312,12 +310,10 @@ scmp(const void *p1, const void *p2) static void initdynexport(void) { - int i; Sym *s; nexport = 0; - for(i=0; i<NHASH; i++) - for(s = hash[i]; s != S; s = s->hash) { + for(s = allsym; s != S; s = s->allsym) { if(!s->reachable || !s->dynimpname || !s->dynexport) continue; if(nexport+1 > sizeof(dexport)/sizeof(dexport[0])) { diff --git a/src/cmd/ld/symtab.c b/src/cmd/ld/symtab.c index 22777b6b5..aefe0b1af 100644 --- a/src/cmd/ld/symtab.c +++ b/src/cmd/ld/symtab.c @@ -340,6 +340,8 @@ putsymb(Sym *s, char *name, int t, vlong v, vlong size, int ver, Sym *typ) void symtab(void) { + Sym *s; + // Define these so that they'll get put into the symbol table. // data.c:/^address will provide the actual values. xdefine("text", STEXT, 0); @@ -351,11 +353,39 @@ symtab(void) xdefine("end", SBSS, 0); xdefine("epclntab", SRODATA, 0); xdefine("esymtab", SRODATA, 0); + + // pseudo-symbols to mark locations of type, string, and go string data. + s = lookup("type.*", 0); + s->type = STYPE; + s->size = 0; + s->reachable = 1; + + s = lookup("go.string.*", 0); + s->type = SGOSTRING; + s->size = 0; + s->reachable = 1; symt = lookup("symtab", 0); symt->type = SRODATA; symt->size = 0; symt->reachable = 1; + + // assign specific types so that they sort together. + // within a type they sort by size, so the .* symbols + // just defined above will be first. + // hide the specific symbols. + for(s = allsym; s != S; s = s->allsym) { + if(!s->reachable || s->special || s->type != SRODATA) + continue; + if(strncmp(s->name, "type.", 5) == 0) { + s->type = STYPE; + s->hide = 1; + } + if(strncmp(s->name, "go.string.", 10) == 0) { + s->type = SGOSTRING; + s->hide = 1; + } + } genasmsym(putsymb); } diff --git a/src/cmd/nm/doc.go b/src/cmd/nm/doc.go index 84a91792f..2a37dd835 100644 --- a/src/cmd/nm/doc.go +++ b/src/cmd/nm/doc.go @@ -11,6 +11,9 @@ Nm is a version of the Plan 9 nm command. 
The original is documented at It prints the name list (symbol table) for programs compiled by gc as well as the Plan 9 C compiler. +This implementation adds the flag -S, which prints each symbol's size +in decimal after its address. + For reasons of disambiguation it is installed as 6nm although it also serves as an 8nm and a 5nm. diff --git a/src/cmd/prof/gopprof b/src/cmd/prof/gopprof index 4bcfa5800..8fa00cbe8 100755 --- a/src/cmd/prof/gopprof +++ b/src/cmd/prof/gopprof @@ -1896,6 +1896,7 @@ sub SvgJavascript { // SVGPan // http://www.cyberz.org/blog/2009/12/08/svgpan-a-javascript-svg-panzoomdrag-library/ // Local modification: if(true || ...) below to force panning, never moving. +// Local modification: add clamping to fix bug in handleMouseWheel. /** * SVGPan library 1.2 @@ -2038,6 +2039,15 @@ function handleMouseWheel(evt) { var z = 1 + delta; // Zoom factor: 0.9/1.1 + // Clamp to reasonable values. + // The 0.1 check is important because + // a very large scroll can turn into a + // negative z, which rotates the image 180 degrees. + if(z < 0.1) + z = 0.1; + if(z > 10.0) + z = 10.0; + var g = svgDoc.getElementById("viewport"); var p = getEventPoint(evt); @@ -2416,11 +2426,17 @@ sub RemoveUninterestingFrames { 'makechan', 'makemap', 'mal', - 'mallocgc', + 'runtime.new', + 'makeslice1', + 'runtime.gostringsize', + 'runtime.malloc', + 'unsafe.New', + 'runtime.mallocgc', 'runtime.catstring', 'runtime.ifaceT2E', 'runtime.ifaceT2I', 'runtime.makechan', + 'runtime.makechan_c', 'runtime.makemap', 'runtime.makeslice', 'runtime.mal', @@ -2460,7 +2476,8 @@ sub RemoveUninterestingFrames { # Nothing skipped for unknown types } - if ($main::profile_type eq 'cpu') { + # Go doesn't have the problem that this heuristic tries to fix. Disable. + if (0 && $main::profile_type eq 'cpu') { # If all the second-youngest program counters are the same, # this STRONGLY suggests that it is an artifact of measurement, # i.e., stack frames pushed by the CPU profiler signal handler. @@ -2960,7 +2977,7 @@ sub FetchDynamicProfile { $url = sprintf("http://$profile_name" . "seconds=%d", $main::opt_seconds); } - $curl_timeout = sprintf("--max-time=%d", + $curl_timeout = sprintf("--max-time %d", int($main::opt_seconds * 1.01 + 60)); } else { # For non-CPU profiles, we add a type-extension to |
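
Note on the new gotype command above: when -p is given it first parses only the package clause and skips files that belong to a different package, then does the full parse with declaration errors enabled. A minimal standalone sketch of that idea, written against the current go/parser API rather than the 2011 one in the patch (parseIfPackage and the usage around it are invented names, not gotype's):

// parsefilter.go: a sketch of gotype's -p filtering, not the tool itself.
package main

import (
	"fmt"
	"go/ast"
	"go/parser"
	"go/token"
	"os"
)

// parseIfPackage fully parses filename only when its package clause matches
// pkgName; an empty pkgName accepts every file. A nil *ast.File with a nil
// error means "ignored", mirroring gotype's parse().
func parseIfPackage(fset *token.FileSet, filename, pkgName string) (*ast.File, error) {
	if pkgName != "" {
		// Cheap first pass: stop after the package clause.
		f, err := parser.ParseFile(fset, filename, nil, parser.PackageClauseOnly)
		if err != nil {
			return nil, err
		}
		if f.Name.Name != pkgName {
			return nil, nil
		}
	}
	// Full parse, reporting declaration errors as the tool does.
	return parser.ParseFile(fset, filename, nil, parser.DeclarationErrors)
}

func main() {
	if len(os.Args) != 2 {
		fmt.Fprintln(os.Stderr, "usage: parsefilter file.go")
		os.Exit(2)
	}
	fset := token.NewFileSet()
	f, err := parseIfPackage(fset, os.Args[1], "main")
	switch {
	case err != nil:
		fmt.Fprintln(os.Stderr, err)
		os.Exit(2)
	case f == nil:
		fmt.Println("ignored (different package)")
	default:
		fmt.Println("parsed package", f.Name.Name)
	}
}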
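
Note on the govet change: the old percent-sign count is replaced by a walk over each formatting directive, consuming flags, an optional '*' or numeric width, an optional precision, then the verb, with "%%" needing no argument and each '*' needing one. A self-contained Go sketch of that counting scheme using today's package layout (countFormatArgs and parseVerb are illustrative names, not vet's):

// countformat.go: a sketch of the directive-counting scheme, not vet's code.
package main

import (
	"fmt"
	"unicode/utf8"
)

// parseVerb reports how many bytes and how many arguments the directive that
// begins s (s[0] == '%') consumes: flags, optional '*' or digits for the
// width, an optional '.' precision, then the verb rune.
func parseVerb(s string) (nbytes, nargs int) {
	nbytes = 1 // the leading '%'
	for nbytes < len(s) {
		switch s[nbytes] {
		case '#', '0', '+', '-', ' ':
			nbytes++
			continue
		}
		break
	}
	num := func() {
		if nbytes < len(s) && s[nbytes] == '*' {
			nbytes++
			nargs++ // a '*' width or precision is an argument itself
			return
		}
		for nbytes < len(s) && '0' <= s[nbytes] && s[nbytes] <= '9' {
			nbytes++
		}
	}
	num() // width
	if nbytes < len(s) && s[nbytes] == '.' {
		nbytes++
		num() // precision
	}
	c, w := utf8.DecodeRuneInString(s[nbytes:])
	nbytes += w
	if c != '%' { // "%%" is a literal percent, no argument
		nargs++
	}
	return
}

// countFormatArgs returns how many arguments a format string expects.
func countFormatArgs(format string) int {
	n := 0
	for i := 0; i < len(format); {
		if format[i] != '%' {
			i++
			continue
		}
		nbytes, nargs := parseVerb(format[i:])
		i += nbytes
		n += nargs
	}
	return n
}

func main() {
	fmt.Println(countFormatArgs("%s%%%d")) // 2, as in the test case above
	fmt.Println(countFormatArgs("%.*d"))   // 2: the '*' and the value
}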
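
Note on the elfdynhash rewrite in ld/elf.c: the loop now walks the new allsym list and, for each dynamic symbol, hashes its name (dynimpname when present, otherwise name) with the classic SysV ELF hash and prepends its dynid to the chosen bucket's chain. A standalone Go transcription of that hash and chain layout (the slice names are illustrative; the real loop works on the linker's Sym structures, not strings):

// elfhash.go: sketch of the .hash construction in the hunk above.
package main

import "fmt"

// elfHash is the classic SysV ELF hash applied to each dynamic symbol name.
func elfHash(name string) uint32 {
	var h uint32
	for i := 0; i < len(name); i++ {
		h = h<<4 + uint32(name[i])
		g := h & 0xf0000000
		h ^= g >> 24
		h &^= g
	}
	return h
}

func main() {
	// Hypothetical dynamic symbols; dynid 0 stays reserved, matching the
	// "dynid <= 0" skip in the C code.
	names := []string{"printf", "malloc", "pthread_create"}
	const nbucket = 3
	buckets := make([]uint32, nbucket)
	chain := make([]uint32, len(names)+1)
	for i, n := range names {
		dynid := uint32(i + 1)
		b := elfHash(n) % nbucket
		chain[dynid] = buckets[b] // prepend this symbol to its bucket's chain
		buckets[b] = dynid
	}
	fmt.Println("buckets:", buckets)
	fmt.Println("chain:  ", chain)
}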
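
Note on the ld/ldmacho.c hunk: for an X86_64_RELOC_SIGNED relocation against a section (extrn == 0) the addend is reconstructed from the RIP-relative 32-bit value by adding the address of the next instruction (relocation offset plus 4) and subtracting the referenced section's vaddr from the object file. That arithmetic as a tiny Go helper, with made-up example numbers (the real code reads every value out of the Mach-O file):

// signedaddend.go: just the arithmetic from the comment above.
package main

import "fmt"

// signedAddend turns the stored RIP-relative value into an offset within the
// referenced section.
func signedAddend(stored int32, relocOff, sectVaddr uint64) int64 {
	return int64(stored) + int64(relocOff) + 4 - int64(sectVaddr)
}

func main() {
	// e.g. movsd 0x00000360(%rip),%xmm0 at offset 0x10, section vaddr 0x200
	fmt.Printf("addend = %#x\n", signedAddend(0x360, 0x10, 0x200))
}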
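
Note on the symbol-table changes in ld/lib.c and its callers: every symbol is now threaded onto a single allsym list at creation, so passes in ld/data.c, ld/elf.c, ld/go.c, and ld/pe.c no longer walk all NHASH buckets, and the new read-only rlookup never creates a symbol. A compact Go sketch of that bookkeeping pattern (Table, Lookup, RLookup, Walk are invented names standing in for the linker's C globals):

// allsym.go: sketch of the allsym list plus read-only lookup pattern.
package main

import "fmt"

type Sym struct {
	Name    string
	Version int
	allsym  *Sym // next symbol on the all-symbols list (newest first)
}

type key struct {
	name    string
	version int
}

type Table struct {
	byName map[key]*Sym // stand-in for the NHASH buckets
	allsym *Sym         // head of the list every symbol is threaded onto
}

func NewTable() *Table { return &Table{byName: make(map[key]*Sym)} }

// Lookup finds or creates a symbol, like lookup() / _lookup(..., creat=1).
func (t *Table) Lookup(name string, version int) *Sym {
	k := key{name, version}
	if s := t.byName[k]; s != nil {
		return s
	}
	s := &Sym{Name: name, Version: version, allsym: t.allsym}
	t.allsym = s // prepend, as lib.c does with s->allsym = allsym
	t.byName[k] = s
	return s
}

// RLookup is the read-only variant, like the new rlookup(): it never creates.
func (t *Table) RLookup(name string, version int) *Sym {
	return t.byName[key{name, version}]
}

// Walk visits every symbol ever created, the way dodata, elfdynhash, doweak
// and undef now iterate over allsym instead of the hash buckets.
func (t *Table) Walk(f func(*Sym)) {
	for s := t.allsym; s != nil; s = s.allsym {
		f(s)
	}
}

func main() {
	t := NewTable()
	t.Lookup("type.*", 0)
	t.Lookup("go.string.*", 0)
	fmt.Println(t.RLookup("weak.missing", 0) == nil) // true: nothing created
	t.Walk(func(s *Sym) { fmt.Println(s.Name) })     // newest first
}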