Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions api/next/51115.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pkg io, type LimitedReader struct, Err error #51115
4 changes: 4 additions & 0 deletions doc/next/6-stdlib/99-minor/io/51115.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
<!-- go.dev/issue/51115 -->
The new [LimitedReader.Err] field allows returning a custom error
when the read limit is exceeded. When nil (the default), [LimitedReader.Read]
returns [EOF] when the limit is reached, maintaining backward compatibility.
22 changes: 22 additions & 0 deletions src/io/example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package io_test

import (
"errors"
"fmt"
"io"
"log"
Expand Down Expand Up @@ -121,6 +122,27 @@ func ExampleLimitReader() {
// some
}

func ExampleLimitedReader_Err() {
r := strings.NewReader("some io.Reader stream to be read\n")
sentinel := errors.New("read limit reached")
lr := &io.LimitedReader{R: r, N: 4, Err: sentinel}

buf := make([]byte, 10)
n, err := lr.Read(buf)
if err != nil {
log.Fatal(err)
}
fmt.Printf("%d; %q\n", n, buf[:n])

// try to read more and get the custom error
n, err = lr.Read(buf)
fmt.Printf("%d; error: %v\n", n, err)

// Output:
// 4; "some"
// 0; error: read limit reached
}

func ExampleMultiReader() {
r1 := strings.NewReader("first reader ")
r2 := strings.NewReader("second reader ")
Expand Down
62 changes: 51 additions & 11 deletions src/io/io.go
Original file line number Diff line number Diff line change
Expand Up @@ -457,28 +457,68 @@ func copyBuffer(dst Writer, src Reader, buf []byte) (written int64, err error) {

// LimitReader returns a Reader that reads from r
// but stops with EOF after n bytes.
// The underlying implementation is a *LimitedReader.
func LimitReader(r Reader, n int64) Reader { return &LimitedReader{r, n} }
// To return a custom error when the limit is reached, construct
// a *LimitedReader directly with the desired Err field.
func LimitReader(r Reader, n int64) Reader { return &LimitedReader{R: r, N: n} }

// A LimitedReader reads from R but limits the amount of
// data returned to just N bytes. Each call to Read
// updates N to reflect the new amount remaining.
// Read returns EOF when N <= 0 or when the underlying R returns EOF.
//
// Negative values of N mean that the limit has been exceeded.
// Read returns Err when more than N bytes are read from R.
// If Err is nil, Read returns EOF.
type LimitedReader struct {
R Reader // underlying reader
N int64 // max bytes remaining
R Reader // underlying reader
N int64 // max bytes remaining
Err error // error to return when limit is exceeded
}

func (l *LimitedReader) Read(p []byte) (n int, err error) {
if l.N <= 0 {
// We use negative l.N values to signal that we've exceeded the limit and cached the result.
const sentinelExceeded = -1 // Probed and found more data available
const sentinelExactMatch = -2 // Probed and found EOF (exactly N bytes)

if len(p) == 0 && l.N <= 0 {
return 0, nil
}

if l.N > 0 {
if int64(len(p)) > l.N {
p = p[0:l.N]
}
n, err = l.R.Read(p)
l.N -= int64(n)
return
}

// Sentinel
if l.N < 0 {
if l.N == sentinelExceeded && l.Err != nil {
return 0, l.Err
}

return 0, EOF
}

// At limit (N == 0) - need to determine if stream has more data

if l.Err == nil {
return 0, EOF
}
if int64(len(p)) > l.N {
p = p[0:l.N]

// Probe with one byte to distinguish two sentinels.
// We can't tell without reading ahead. This probe permanently consumes
// a byte from R, so we cache the result in N to avoid re-probing.
var probe [1]byte
probeN, probeErr := l.R.Read(probe[:])
if probeN > 0 || probeErr != EOF {
l.N = sentinelExceeded
return 0, l.Err
}
n, err = l.R.Read(p)
l.N -= int64(n)
return

l.N = sentinelExactMatch
return 0, EOF
}

// NewSectionReader returns a [SectionReader] that reads from r
Expand Down
212 changes: 212 additions & 0 deletions src/io/io_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -692,3 +692,215 @@ func TestOffsetWriter_Write(t *testing.T) {
checkContent(name, f)
})
}

var errLimit = errors.New("limit exceeded")

func TestLimitedReader(t *testing.T) {
src := strings.NewReader("abc")
r := LimitReader(src, 5)
lr, ok := r.(*LimitedReader)
if !ok {
t.Fatalf("LimitReader should return *LimitedReader, got %T", r)
}
if lr.R != src || lr.N != 5 || lr.Err != nil {
t.Fatalf("LimitReader() = {R: %v, N: %d, Err: %v}, want {R: %v, N: 5, Err: nil}", lr.R, lr.N, lr.Err, src)
}

t.Run("WithoutCustomErr", func(t *testing.T) {
tests := []struct {
name string
data string
limit int64
want1N int
want1E error
want2E error
}{
{"UnderLimit", "hello", 10, 5, nil, EOF},
{"ExactLimit", "hello", 5, 5, nil, EOF},
{"OverLimit", "hello world", 5, 5, nil, EOF},
{"ZeroLimit", "hello", 0, 0, EOF, EOF},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
lr := &LimitedReader{R: strings.NewReader(tt.data), N: tt.limit}
buf := make([]byte, 10)

n, err := lr.Read(buf)
if n != tt.want1N || err != tt.want1E {
t.Errorf("first Read() = (%d, %v), want (%d, %v)", n, err, tt.want1N, tt.want1E)
}

n, err = lr.Read(buf)
if n != 0 || err != tt.want2E {
t.Errorf("second Read() = (%d, %v), want (0, %v)", n, err, tt.want2E)
}
})
}
})

t.Run("WithCustomErr", func(t *testing.T) {
tests := []struct {
name string
data string
limit int64
err error
wantFirst string
wantErr1 error
wantErr2 error
}{
{"ExactLimit", "hello", 5, errLimit, "hello", nil, EOF},
{"OverLimit", "hello world", 5, errLimit, "hello", nil, errLimit},
{"UnderLimit", "hi", 5, errLimit, "hi", nil, EOF},
{"ZeroLimitEmpty", "", 0, errLimit, "", EOF, EOF},
{"ZeroLimitNonEmpty", "hello", 0, errLimit, "", errLimit, errLimit},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
lr := &LimitedReader{R: strings.NewReader(tt.data), N: tt.limit, Err: tt.err}
buf := make([]byte, 10)

n, err := lr.Read(buf)
if n != len(tt.wantFirst) || string(buf[:n]) != tt.wantFirst || err != tt.wantErr1 {
t.Errorf("first Read() = (%d, %q, %v), want (%d, %q, %v)", n, buf[:n], err, len(tt.wantFirst), tt.wantFirst, tt.wantErr1)
}

n, err = lr.Read(buf)
if n != 0 || err != tt.wantErr2 {
t.Errorf("second Read() = (%d, %v), want (0, %v)", n, err, tt.wantErr2)
}
})
}
})

t.Run("CustomErrPersists", func(t *testing.T) {
lr := &LimitedReader{R: strings.NewReader("hello world"), N: 5, Err: errLimit}
buf := make([]byte, 10)

n, err := lr.Read(buf)
if n != 5 || err != nil || string(buf[:5]) != "hello" {
t.Errorf(`Read() = (%d, %v, %q), want (5, nil, "hello")`, n, err, buf[:5])
}

n, err = lr.Read(buf)
if n != 0 || err != errLimit {
t.Errorf("Read() = (%d, %v), want (0, errLimit)", n, err)
}

n, err = lr.Read(buf)
if n != 0 || err != errLimit {
t.Errorf("Read() = (%d, %v), want (0, errLimit)", n, err)
}
})

t.Run("ErrEOF", func(t *testing.T) {
lr := &LimitedReader{R: strings.NewReader("hello world"), N: 5, Err: EOF}
buf := make([]byte, 10)

n, err := lr.Read(buf)
if n != 5 || err != nil {
t.Errorf("Read() = (%d, %v), want (5, nil)", n, err)
}

n, err = lr.Read(buf)
if n != 0 || err != EOF {
t.Errorf("Read() = (%d, %v), want (0, EOF)", n, err)
}
})

t.Run("NoSideEffects", func(t *testing.T) {
lr := &LimitedReader{R: strings.NewReader("hello"), N: 5, Err: errLimit}
buf := make([]byte, 0)

for i := 0; i < 3; i++ {
n, err := lr.Read(buf)
if n != 0 || err != nil {
t.Errorf("zero-length read #%d = (%d, %v), want (0, nil)", i+1, n, err)
}
if lr.N != 5 {
t.Errorf("N after zero-length read #%d = %d, want 5", i+1, lr.N)
}
}

buf = make([]byte, 10)
n, err := lr.Read(buf)
if n != 5 || string(buf[:5]) != "hello" || err != nil {
t.Errorf(`normal Read() = (%d, %q, %v), want (5, "hello", nil)`, n, buf[:5], err)
}
})
}

type errorReader struct {
data []byte
pos int
err error
}

func (r *errorReader) Read(p []byte) (int, error) {
if r.pos >= len(r.data) {
return 0, r.err
}
n := copy(p, r.data[r.pos:])
r.pos += n
return n, nil
}

func TestLimitedReaderErrors(t *testing.T) {
t.Run("UnderlyingError", func(t *testing.T) {
underlyingErr := errors.New("boom")
lr := &LimitedReader{R: &errorReader{data: []byte("hello"), err: underlyingErr}, N: 10}
buf := make([]byte, 10)

n, err := lr.Read(buf)
if n != 5 || string(buf[:5]) != "hello" || err != nil {
t.Errorf(`first Read() = (%d, %q, %v), want (5, "hello", nil)`, n, buf[:5], err)
}

n, err = lr.Read(buf)
if n != 0 || err != underlyingErr {
t.Errorf("second Read() = (%d, %v), want (0, %v)", n, err, underlyingErr)
}
})

t.Run("SentinelMasksProbeError", func(t *testing.T) {
probeErr := errors.New("probe failed")
lr := &LimitedReader{R: &errorReader{data: []byte("hello"), err: probeErr}, N: 5, Err: errLimit}
buf := make([]byte, 10)

n, err := lr.Read(buf)
if n != 5 || string(buf[:5]) != "hello" || err != nil {
t.Errorf(`first Read() = (%d, %q, %v), want (5, "hello", nil)`, n, buf[:5], err)
}

n, err = lr.Read(buf)
if n != 0 || err != errLimit {
t.Errorf("second Read() = (%d, %v), want (0, errLimit)", n, err)
}
})
}

func TestLimitedReaderCopy(t *testing.T) {
tests := []struct {
name string
input string
limit int64
wantN int64
wantErr error
}{
{"Exact", "hello", 5, 5, nil},
{"Under", "hi", 5, 2, nil},
{"Over", "hello world", 5, 5, errLimit},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
lr := &LimitedReader{R: strings.NewReader(tt.input), N: tt.limit, Err: errLimit}
var dst Buffer
n, err := Copy(&dst, lr)
if n != tt.wantN || err != tt.wantErr {
t.Errorf("Copy() = (%d, %v), want (%d, %v)", n, err, tt.wantN, tt.wantErr)
}
})
}
}