-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
49c60d6
commit 615aaa0
Showing
2 changed files
with
374 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
// Copyright 2023 igevin | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
package queue | ||
|
||
import ( | ||
"github.com/igevin/algokit/collection/queue" | ||
"github.com/igevin/algokit/comparator" | ||
"sync" | ||
) | ||
|
||
type ConcurrentPriorityQueue[T any] struct { | ||
pq queue.PriorityQueue[T] | ||
m sync.RWMutex | ||
} | ||
|
||
func (c *ConcurrentPriorityQueue[T]) Len() int { | ||
c.m.RLock() | ||
defer c.m.RUnlock() | ||
return c.pq.Len() | ||
} | ||
|
||
// Cap 无界队列返回0,有界队列返回创建队列时设置的值 | ||
func (c *ConcurrentPriorityQueue[T]) Cap() int { | ||
c.m.RLock() | ||
defer c.m.RUnlock() | ||
return c.pq.Cap() | ||
} | ||
|
||
func (c *ConcurrentPriorityQueue[T]) Peek() (T, error) { | ||
c.m.RLock() | ||
defer c.m.RUnlock() | ||
return c.pq.Peek() | ||
} | ||
|
||
func (c *ConcurrentPriorityQueue[T]) Enqueue(t T) error { | ||
c.m.Lock() | ||
defer c.m.Unlock() | ||
return c.pq.Enqueue(t) | ||
} | ||
|
||
func (c *ConcurrentPriorityQueue[T]) Dequeue() (T, error) { | ||
c.m.Lock() | ||
defer c.m.Unlock() | ||
return c.pq.Dequeue() | ||
} | ||
|
||
// NewConcurrentPriorityQueue 创建优先队列 capacity <= 0 时,为无界队列 | ||
func NewConcurrentPriorityQueue[T any](capacity int, compare comparator.Compare[T]) *ConcurrentPriorityQueue[T] { | ||
return &ConcurrentPriorityQueue[T]{ | ||
pq: *queue.NewPriorityQueue[T](capacity, compare), | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,310 @@ | ||
// Copyright 2023 igevin | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
package queue | ||
|
||
import ( | ||
"fmt" | ||
"github.com/igevin/algokit/collection/queue" | ||
"github.com/igevin/algokit/comparator" | ||
"github.com/stretchr/testify/assert" | ||
"github.com/stretchr/testify/require" | ||
"sync" | ||
"testing" | ||
) | ||
|
||
var ( | ||
errOutOfCapacity = queue.ErrOutOfCapacity | ||
errEmptyQueue = queue.ErrEmptyQueue | ||
) | ||
|
||
func TestNewConcurrentPriorityQueue(t *testing.T) { | ||
testCases := []struct { | ||
name string | ||
q *ConcurrentPriorityQueue[int] | ||
capacity int | ||
data []int | ||
expect []int | ||
}{ | ||
{ | ||
name: "无边界", | ||
q: NewConcurrentPriorityQueue(0, comparator.PrimeComparator[int]), | ||
capacity: 0, | ||
data: []int{6, 5, 4, 3, 2, 1}, | ||
expect: []int{1, 2, 3, 4, 5, 6}, | ||
}, | ||
{ | ||
name: "有边界 ", | ||
q: NewConcurrentPriorityQueue(6, comparator.PrimeComparator[int]), | ||
capacity: 6, | ||
data: []int{6, 5, 4, 3, 2, 1}, | ||
expect: []int{1, 2, 3, 4, 5, 6}, | ||
}, | ||
} | ||
for _, tc := range testCases { | ||
t.Run(tc.name, func(t *testing.T) { | ||
assert.Equal(t, 0, tc.q.Len()) | ||
for _, d := range tc.data { | ||
require.NoError(t, tc.q.Enqueue(d)) | ||
} | ||
assert.Equal(t, tc.capacity, tc.q.Cap()) | ||
assert.Equal(t, len(tc.data), tc.q.Len()) | ||
res := make([]int, 0, len(tc.data)) | ||
for tc.q.Len() > 0 { | ||
head, err := tc.q.Peek() | ||
require.NoError(t, err) | ||
el, err := tc.q.Dequeue() | ||
require.NoError(t, err) | ||
assert.Equal(t, head, el) | ||
res = append(res, el) | ||
} | ||
assert.Equal(t, tc.expect, res) | ||
}) | ||
|
||
} | ||
|
||
} | ||
|
||
// 多个go routine 执行入队操作,完成后,主携程把元素逐一出队,只要有序,可以认为并发入队没问题 | ||
func TestConcurrentPriorityQueue_Enqueue(t *testing.T) { | ||
testCases := []struct { | ||
name string | ||
capacity int | ||
concurrency int | ||
perRoutine int | ||
wantSlice []int | ||
remain int | ||
wantErr error | ||
errCount int | ||
}{ | ||
{ | ||
name: "不超过capacity", | ||
capacity: 1100, | ||
concurrency: 100, | ||
perRoutine: 10, | ||
}, | ||
{ | ||
name: "超过capacity", | ||
capacity: 1000, | ||
concurrency: 101, | ||
perRoutine: 10, | ||
wantErr: errOutOfCapacity, | ||
errCount: 10, | ||
}, | ||
} | ||
for _, tc := range testCases { | ||
t.Run(tc.name, func(t *testing.T) { | ||
q := NewConcurrentPriorityQueue[int](tc.capacity, comparator.PrimeComparator[int]) | ||
wg := sync.WaitGroup{} | ||
wg.Add(tc.concurrency) | ||
errChan := make(chan error, tc.capacity) | ||
for i := tc.concurrency; i > 0; i-- { | ||
go func(i int) { | ||
start := i * 10 | ||
for j := 0; j < tc.perRoutine; j++ { | ||
err := q.Enqueue(start + j) | ||
if err != nil { | ||
errChan <- err | ||
} | ||
} | ||
wg.Done() | ||
}(i) | ||
} | ||
wg.Wait() | ||
assert.Equal(t, tc.errCount, len(errChan)) | ||
prev := -1 | ||
for q.Len() > 0 { | ||
el, _ := q.Dequeue() | ||
assert.Less(t, prev, el) | ||
|
||
// 入队元素总数小于capacity时,应该所有元素都入队了,出队顺序应该依次加1 | ||
if prev > -1 && len(errChan) == 0 { | ||
assert.Equal(t, prev+1, el) | ||
} | ||
prev = el | ||
} | ||
}) | ||
|
||
} | ||
} | ||
|
||
// 预先入队一组数据,通过测试多个协程并发出队时,每个协程内出队元素有序,间接确认并发安全 | ||
func TestConcurrentPriorityQueue_Dequeue(t *testing.T) { | ||
testCases := []struct { | ||
name string | ||
total int | ||
concurrency int | ||
perRoutine int | ||
wantSlice []int | ||
remain int | ||
wantErr error | ||
errCount int | ||
}{ | ||
{ | ||
name: "入队大于出队", | ||
total: 910, | ||
concurrency: 100, | ||
perRoutine: 9, | ||
remain: 10, | ||
}, | ||
{ | ||
name: "入队小于出队", | ||
total: 900, | ||
concurrency: 101, | ||
perRoutine: 9, | ||
remain: 0, | ||
wantErr: errEmptyQueue, | ||
errCount: 9, | ||
}, | ||
} | ||
for _, tc := range testCases { | ||
t.Run(tc.name, func(t *testing.T) { | ||
q := NewConcurrentPriorityQueue[int](tc.total, comparator.PrimeComparator[int]) | ||
for i := tc.total; i > 0; i-- { | ||
require.NoError(t, q.Enqueue(i)) | ||
} | ||
|
||
resultChan := make(chan int, tc.concurrency*tc.perRoutine) | ||
disOrderChan := make(chan bool, tc.concurrency*tc.perRoutine) | ||
errChan := make(chan error, tc.errCount) | ||
wg := sync.WaitGroup{} | ||
wg.Add(tc.concurrency) | ||
|
||
for i := 0; i < tc.concurrency; i++ { | ||
go func() { | ||
prev := -1 | ||
for i := 0; i < tc.perRoutine; i++ { | ||
el, err := q.Dequeue() | ||
if err != nil { | ||
// 如果出队报错,把错误放到error通道,以便后续检查错误的内容和数量是否符合预期 | ||
errChan <- err | ||
} else { | ||
// 如果出队不报错,则检查出队结果是否符合预期 | ||
resultChan <- el | ||
if prev >= el { | ||
disOrderChan <- false | ||
} | ||
prev = el | ||
} | ||
|
||
} | ||
wg.Done() | ||
}() | ||
} | ||
wg.Wait() | ||
close(resultChan) | ||
close(errChan) | ||
close(disOrderChan) | ||
|
||
// 检查并发出队的元素数量,是否符合预期 | ||
assert.Equal(t, tc.remain, q.Len()) | ||
|
||
// 检查所有协程中的执行错误,是否符合预期 | ||
assert.Equal(t, tc.errCount, len(errChan)) | ||
for err := range errChan { | ||
assert.Equal(t, tc.wantErr, err) | ||
} | ||
|
||
// 每个协程内部,出队元素应该有序,检查是否发现无序的情况 | ||
assert.Equal(t, 0, len(disOrderChan)) | ||
|
||
// 每个协程的每次出队操作,出队元素都应该不同,检查是否符合预期 | ||
resultSet := make(map[int]bool) | ||
for el := range resultChan { | ||
_, ok := resultSet[el] | ||
assert.Equal(t, false, ok) | ||
resultSet[el] = true | ||
} | ||
|
||
}) | ||
|
||
} | ||
} | ||
|
||
// 测试同时并发出入队。只要并发安全,并发出入队后的剩余元素数量+报错数量应该符合预期 | ||
// TODO 有待设计更好的并发出入队测试方案 | ||
func TestConcurrentPriorityQueue_EnqueueDequeue(t *testing.T) { | ||
testCases := []struct { | ||
name string | ||
enqueue int | ||
dequeue int | ||
remain int | ||
}{ | ||
{ | ||
name: "出队等于入队", | ||
enqueue: 50, | ||
dequeue: 50, | ||
remain: 0, | ||
}, | ||
{ | ||
name: "出队小于入队", | ||
enqueue: 50, | ||
dequeue: 40, | ||
remain: 10, | ||
}, | ||
{ | ||
name: "出队大于入队", | ||
enqueue: 50, | ||
dequeue: 60, | ||
remain: -10, | ||
}, | ||
} | ||
for _, tt := range testCases { | ||
tc := tt | ||
t.Run(tc.name, func(t *testing.T) { | ||
q := NewConcurrentPriorityQueue[int](0, comparator.PrimeComparator[int]) | ||
errChan := make(chan error, tc.dequeue) | ||
wg := sync.WaitGroup{} | ||
wg.Add(tc.enqueue + tc.dequeue) | ||
go func() { | ||
for i := 0; i < tc.enqueue; i++ { | ||
go func(i int) { | ||
require.NoError(t, q.Enqueue(i)) | ||
wg.Done() | ||
}(i) | ||
} | ||
}() | ||
go func() { | ||
for i := 0; i < tc.dequeue; i++ { | ||
_, err := q.Dequeue() | ||
if err != nil { | ||
errChan <- err | ||
} | ||
wg.Done() | ||
} | ||
}() | ||
|
||
wg.Wait() | ||
close(errChan) | ||
assert.Equal(t, tc.remain, q.Len()-len(errChan)) | ||
}) | ||
} | ||
} | ||
|
||
func ExampleNewConcurrentPriorityQueue() { | ||
q := NewConcurrentPriorityQueue[int](10, comparator.PrimeComparator[int]) | ||
_ = q.Enqueue(3) | ||
_ = q.Enqueue(2) | ||
_ = q.Enqueue(1) | ||
var vals []int | ||
val, _ := q.Dequeue() | ||
vals = append(vals, val) | ||
val, _ = q.Dequeue() | ||
vals = append(vals, val) | ||
val, _ = q.Dequeue() | ||
vals = append(vals, val) | ||
fmt.Println(vals) | ||
// Output: | ||
// [1 2 3] | ||
} |