diff options
Diffstat (limited to 'src/pkg/encoding/base64/base64.go')
-rw-r--r-- | src/pkg/encoding/base64/base64.go | 39 |
1 files changed, 24 insertions, 15 deletions
diff --git a/src/pkg/encoding/base64/base64.go b/src/pkg/encoding/base64/base64.go index 85e398fd0..e38c26d0e 100644 --- a/src/pkg/encoding/base64/base64.go +++ b/src/pkg/encoding/base64/base64.go @@ -159,13 +159,11 @@ func (e *encoder) Write(p []byte) (n int, err error) { nn := len(e.out) / 4 * 3 if nn > len(p) { nn = len(p) + nn -= nn % 3 } - nn -= nn % 3 - if nn > 0 { - e.enc.Encode(e.out[0:], p[0:nn]) - if _, e.err = e.w.Write(e.out[0 : nn/3*4]); e.err != nil { - return n, e.err - } + e.enc.Encode(e.out[0:], p[0:nn]) + if _, e.err = e.w.Write(e.out[0 : nn/3*4]); e.err != nil { + return n, e.err } n += nn p = p[nn:] @@ -226,21 +224,33 @@ func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) { var dbuf [4]byte dlen := 4 - for j := 0; j < 4; { + for j := range dbuf { if len(src) == 0 { return n, false, CorruptInputError(olen - len(src) - j) } in := src[0] src = src[1:] - if in == '=' && j >= 2 && len(src) < 4 { + if in == '=' { // We've reached the end and there's padding - if len(src)+j < 4-1 { - // not enough padding - return n, false, CorruptInputError(olen) - } - if len(src) > 0 && src[0] != '=' { + switch j { + case 0, 1: // incorrect padding return n, false, CorruptInputError(olen - len(src) - 1) + case 2: + // "==" is expected, the first "=" is already consumed. + if len(src) == 0 { + // not enough padding + return n, false, CorruptInputError(olen) + } + if src[0] != '=' { + // incorrect padding + return n, false, CorruptInputError(olen - len(src) - 1) + } + src = src[1:] + } + if len(src) > 0 { + // trailing garbage + err = CorruptInputError(olen - len(src)) } dlen, end = j, true break @@ -249,7 +259,6 @@ func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) { if dbuf[j] == 0xFF { return n, false, CorruptInputError(olen - len(src) - 1) } - j++ } // Pack 4x 6-bit source blocks into 3 byte destination @@ -268,7 +277,7 @@ func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) { n += dlen - 1 } - return n, end, nil + return n, end, err } // Decode decodes src using the encoding enc. It writes at most |