riff: fix some short chunk data bugs.

Fixes golang/go#16236

Change-Id: I0e524054d0702a6487ff47d86aed6bf58f4ba3f2
Reviewed-on: https://go-review.googlesource.com/24638
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
This commit is contained in:
Nigel Tao 2016-07-01 16:56:48 +10:00
parent 8550bb5380
commit a21e5be7b4
2 changed files with 86 additions and 3 deletions

View File

@ -23,6 +23,7 @@ import (
var ( var (
errMissingPaddingByte = errors.New("riff: missing padding byte") errMissingPaddingByte = errors.New("riff: missing padding byte")
errMissingRIFFChunkHeader = errors.New("riff: missing RIFF chunk header") errMissingRIFFChunkHeader = errors.New("riff: missing RIFF chunk header")
errListSubchunkTooLong = errors.New("riff: list subchunk too long")
errShortChunkData = errors.New("riff: short chunk data") errShortChunkData = errors.New("riff: short chunk data")
errShortChunkHeader = errors.New("riff: short chunk header") errShortChunkHeader = errors.New("riff: short chunk header")
errStaleReader = errors.New("riff: stale reader") errStaleReader = errors.New("riff: stale reader")
@ -100,13 +101,23 @@ func (z *Reader) Next() (chunkID FourCC, chunkLen uint32, chunkData io.Reader, e
// Drain the rest of the previous chunk. // Drain the rest of the previous chunk.
if z.chunkLen != 0 { if z.chunkLen != 0 {
_, z.err = io.Copy(ioutil.Discard, z.chunkReader) want := z.chunkLen
var got int64
got, z.err = io.Copy(ioutil.Discard, z.chunkReader)
if z.err == nil && uint32(got) != want {
z.err = errShortChunkData
}
if z.err != nil { if z.err != nil {
return FourCC{}, 0, nil, z.err return FourCC{}, 0, nil, z.err
} }
} }
z.chunkReader = nil z.chunkReader = nil
if z.padded { if z.padded {
if z.totalLen == 0 {
z.err = errListSubchunkTooLong
return FourCC{}, 0, nil, z.err
}
z.totalLen--
_, z.err = io.ReadFull(z.r, z.buf[:1]) _, z.err = io.ReadFull(z.r, z.buf[:1])
if z.err != nil { if z.err != nil {
if z.err == io.EOF { if z.err == io.EOF {
@ -114,7 +125,6 @@ func (z *Reader) Next() (chunkID FourCC, chunkLen uint32, chunkData io.Reader, e
} }
return FourCC{}, 0, nil, z.err return FourCC{}, 0, nil, z.err
} }
z.totalLen--
} }
// We are done if we have no more data. // We are done if we have no more data.
@ -129,7 +139,7 @@ func (z *Reader) Next() (chunkID FourCC, chunkLen uint32, chunkData io.Reader, e
return FourCC{}, 0, nil, z.err return FourCC{}, 0, nil, z.err
} }
z.totalLen -= chunkHeaderSize z.totalLen -= chunkHeaderSize
if _, err = io.ReadFull(z.r, z.buf[:chunkHeaderSize]); err != nil { if _, z.err = io.ReadFull(z.r, z.buf[:chunkHeaderSize]); z.err != nil {
if z.err == io.EOF || z.err == io.ErrUnexpectedEOF { if z.err == io.EOF || z.err == io.ErrUnexpectedEOF {
z.err = errShortChunkHeader z.err = errShortChunkHeader
} }
@ -137,6 +147,10 @@ func (z *Reader) Next() (chunkID FourCC, chunkLen uint32, chunkData io.Reader, e
} }
chunkID = FourCC{z.buf[0], z.buf[1], z.buf[2], z.buf[3]} chunkID = FourCC{z.buf[0], z.buf[1], z.buf[2], z.buf[3]}
z.chunkLen = u32(z.buf[4:]) z.chunkLen = u32(z.buf[4:])
if z.chunkLen > z.totalLen {
z.err = errListSubchunkTooLong
return FourCC{}, 0, nil, z.err
}
z.padded = z.chunkLen&1 == 1 z.padded = z.chunkLen&1 == 1
z.chunkReader = &chunkReader{z} z.chunkReader = &chunkReader{z}
return chunkID, z.chunkLen, z.chunkReader, nil return chunkID, z.chunkLen, z.chunkReader, nil

69
riff/riff_test.go Normal file
View File

@ -0,0 +1,69 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package riff
import (
"bytes"
"testing"
)
func encodeU32(u uint32) []byte {
return []byte{
byte(u >> 0),
byte(u >> 8),
byte(u >> 16),
byte(u >> 24),
}
}
func TestShortChunks(t *testing.T) {
// s is a RIFF(ABCD) with allegedly 256 bytes of data (excluding the
// leading 8-byte "RIFF\x00\x01\x00\x00"). The first chunk of that ABCD
// list is an abcd chunk of length m followed by n zeroes.
for _, m := range []uint32{0, 8, 15, 200, 300} {
for _, n := range []int{0, 1, 2, 7} {
s := []byte("RIFF\x00\x01\x00\x00ABCDabcd")
s = append(s, encodeU32(m)...)
s = append(s, make([]byte, n)...)
_, r, err := NewReader(bytes.NewReader(s))
if err != nil {
t.Errorf("m=%d, n=%d: NewReader: %v", m, n, err)
continue
}
_, _, _, err0 := r.Next()
// The total "ABCD" list length is 256 bytes, of which the first 12
// bytes are "ABCDabcd" plus the 4-byte encoding of m. If the
// "abcd" subchunk length (m) plus those 12 bytes is greater than
// the total list length, we have an invalid RIFF, and we expect an
// errListSubchunkTooLong error.
if m+12 > 256 {
if err0 != errListSubchunkTooLong {
t.Errorf("m=%d, n=%d: Next #0: got %v, want %v", m, n, err0, errListSubchunkTooLong)
}
continue
}
// Otherwise, we expect a nil error.
if err0 != nil {
t.Errorf("m=%d, n=%d: Next #0: %v", m, n, err0)
continue
}
_, _, _, err1 := r.Next()
// If m > 0, then m > n, so that "abcd" subchunk doesn't have m
// bytes of data. If m == 0, then that "abcd" subchunk is OK in
// that it has 0 extra bytes of data, but the next subchunk (8 byte
// header plus body) is missing, as we only have n < 8 more bytes.
want := errShortChunkData
if m == 0 {
want = errShortChunkHeader
}
if err1 != want {
t.Errorf("m=%d, n=%d: Next #1: got %v, want %v", m, n, err1, want)
continue
}
}
}
}