gopackage v2
import (
"context"
"sync"
constant "gl.fotechwealth.com.local/backend/trade-lib.git/contexthelper/constant"
)
var (
	// ContextKey is the context key under which this package attaches its
	// *valuesMap. It is the shared key from the constant package; LoadAll also
	// accepts an old-style map[string][]interface{} stored under it (presumably
	// written by the older contexthelper package — confirm).
	ContextKey = constant.SharedContextKey
)
// valuesMap is the mutable bag of values that Store and StoreSingleValue attach
// to a context. Each key maps to the values stored under it, in insertion order.
type valuesMap struct {
	// store holds the values per key; guarded by lock.
	store map[string][]any
	// parentCtx is the context this map was created from; LoadAll continues
	// the lookup there when a key has no values in this map.
	parentCtx context.Context
	// lock guards store.
	lock sync.RWMutex
}
// NewValuesMap returns an empty values map whose parent is parentCtx. Lookups
// that find no values in the returned map continue in parentCtx (see LoadAll).
// Attach it with context.WithValue(parentCtx, ContextKey, vm), as Store does.
func NewValuesMap(parentCtx context.Context) *valuesMap {
	vm := new(valuesMap)
	vm.store = make(map[string][]interface{})
	vm.parentCtx = parentCtx
	return vm
}
// Store appends values, in order, to the list kept under key and returns a
// context that carries the values map.
//
// If ctx already carries a *valuesMap under ContextKey, that map is modified in
// place and ctx itself is returned, so the new values are visible through every
// context sharing the map (including the one passed in). Otherwise a new map,
// with ctx as its parent, is attached to a context derived from ctx.
// Calling Store with no values only attaches the map.
func Store[T any](ctx context.Context, key string, values ...T) context.Context {
	target, attached := ctx.Value(ContextKey).(*valuesMap)
	if !attached {
		target = NewValuesMap(ctx)
		ctx = context.WithValue(ctx, ContextKey, target)
	}

	target.lock.Lock()
	defer target.lock.Unlock()
	for i := range values {
		target.store[key] = append(target.store[key], values[i])
	}
	return ctx
}
// StoreSingleValue makes value the only value stored under key and returns a
// context that carries the values map.
//
// If ctx already carries a *valuesMap under ContextKey, that map is modified in
// place and ctx itself is returned; otherwise a new map, with ctx as its parent,
// is attached to a context derived from ctx.
func StoreSingleValue[T any](ctx context.Context, key string, value T) context.Context {
	target, attached := ctx.Value(ContextKey).(*valuesMap)
	if !attached {
		target = NewValuesMap(ctx)
		ctx = context.WithValue(ctx, ContextKey, target)
	}

	replacement := []interface{}{value}
	target.lock.Lock()
	defer target.lock.Unlock()
	target.store[key] = replacement
	return ctx
}
// Load returns the first value of type T that LoadAll finds for key, or the
// zero value of T when there is none.
func Load[T any](ctx context.Context, key string) (value T) {
	if found := LoadAll[T](ctx, key); len(found) != 0 {
		value = found[0]
	}
	return value
}
// LoadAll returns every value stored under key whose dynamic type is T, in
// insertion order.
//
// Lookup order:
//   - If ctx carries a *valuesMap under ContextKey and that map has at least one
//     value for key, the T-typed values from it are returned (nil if none of
//     them is a T; there is no fallback in that case).
//   - If that map has no values for key, the lookup continues in the context
//     the map was created from (its parentCtx), if any.
//   - If ctx instead carries an old-style map[string][]interface{} under
//     ContextKey, the T-typed values for key from that map are returned.
//   - Otherwise nil is returned.
func LoadAll[T any](ctx context.Context, key string) []T {
	switch m := ctx.Value(ContextKey).(type) {
	case *valuesMap:
		// Convert under the read lock, but release it before walking up to the
		// parent so the lock is not held across the recursive lookup.
		m.lock.RLock()
		values := m.store[key]
		var result []T
		if len(values) > 0 {
			result = filterByType[T](values)
		}
		parent := m.parentCtx
		m.lock.RUnlock()

		if len(values) > 0 {
			return result
		}
		if parent != nil {
			return LoadAll[T](parent, key)
		}
		return nil
	case map[string][]interface{}:
		// Context written by the old (non-generic) version of this helper.
		return filterByType[T](m[key])
	default:
		return nil
	}
}

// filterByType returns the elements of values whose dynamic type is T, in
// order, or nil if there are none.
func filterByType[T any](values []interface{}) []T {
	var result []T
	for _, v := range values {
		if value, ok := v.(T); ok {
			result = append(result, value)
		}
	}
	return result
}
gopackage v2_test
import (
"context"
"fmt"
"os"
"sync"
"testing"
uuid "github.com/satori/go.uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
"gl.fotechwealth.com.local/backend/trade-lib.git/contexthelper"
v2 "gl.fotechwealth.com.local/backend/trade-lib.git/contexthelper/v2"
)
// TestSuite groups the testify-suite based tests for the v2 helpers.
type TestSuite struct {
	suite.Suite // provides the s.Equal-style assertion methods
}
// TestRun stores values of different types with v2.Store and reads them back
// through a context that also has plain context.WithValue entries below and
// above the helper's map; then it checks that repeated Stores accumulate.
func (s *TestSuite) TestRun() {
	var (
		requestId      = uuid.NewV4().String()
		requestContext = "POST:/ping"
		uin            = "1"
		pid            = os.Getpid()
	)

	ctx := context.WithValue(context.TODO(), "parent", "1")
	ctx = v2.Store(ctx, "requestId", requestId)
	ctx = v2.Store(ctx, "requestContext", requestContext)
	ctx = v2.Store(ctx, "uin", uin)
	ctx = v2.Store(ctx, "pid", pid)
	// Using the standard context package on top is completely unaffected.
	ctx = context.WithValue(ctx, "son", "2")

	s.Equal(requestId, v2.Load[string](ctx, "requestId"))
	s.Equal(requestContext, v2.Load[string](ctx, "requestContext"))
	s.Equal(uin, v2.Load[string](ctx, "uin"))
	s.Equal(pid, v2.Load[int](ctx, "pid"))
	s.Equal("1", ctx.Value("parent"))
	s.Equal("2", ctx.Value("son"))

	// Store more values under the same key; they accumulate in order.
	for _, extra := range []string{"2", "3"} {
		ctx = v2.Store(ctx, "uin", extra)
	}
	s.Equal([]string{"1", "2", "3"}, v2.LoadAll[string](ctx, "uin"))
}
func TestAll(t *testing.T) {
suite.Run(t, &TestSuite{})
}
// TestStoreMultipleValues checks that one Store call with several values keeps
// all of them under the same key, in order.
func TestStoreMultipleValues(t *testing.T) {
	requestId := uuid.NewV4().String()
	// Store several values under the same key in a single call.
	ctx := v2.Store(context.Background(), "requestId", requestId, "anotherId", "yetAnotherId")
	// Compare the whole slice at once: indexing values[i] after a failed
	// (non-fatal) length check would panic and hide the real failure.
	assert.Equal(t, []string{requestId, "anotherId", "yetAnotherId"}, v2.LoadAll[string](ctx, "requestId"))
}
// TestStoreSingleValue checks that StoreSingleValue replaces, rather than
// appends to, the value stored under a key.
func TestStoreSingleValue(t *testing.T) {
	t.Run("overwrite", func(t *testing.T) {
		ctx := context.Background()
		requestId := uuid.NewV4().String()
		ctx = v2.StoreSingleValue(ctx, "requestId", requestId)
		assert.Equal(t, requestId, v2.Load[string](ctx, "requestId"))

		newRequestId := uuid.NewV4().String()
		ctx = v2.StoreSingleValue(ctx, "requestId", newRequestId)
		// Only the latest value may remain under the key.
		assert.Equal(t, []string{newRequestId}, v2.LoadAll[string](ctx, "requestId"))
	})
}
// TestWriteDoesNotAffectParentContext checks that storing into a context with
// no values map yet leaves the original context untouched.
func TestWriteDoesNotAffectParentContext(t *testing.T) {
	parentCtx := context.Background()
	childCtx := v2.Store(parentCtx, "childKey", "childValue")

	// The child context must see the new value...
	assert.Equal(t, "childValue", v2.Load[string](childCtx, "childKey"))
	// ...while the parent context must not.
	assert.Equal(t, "", v2.Load[string](parentCtx, "childKey"))
}
// TestCascadingReadFromParentContext checks that a context returned by Store
// can read both its own value and one stored earlier on its parent.
//
// NOTE(review): parentCtx already carries a values map, so the second Store
// writes into that same map rather than a separate child layer; the parent
// therefore also sees "childKey". (Original TODO: improve this test.)
func TestCascadingReadFromParentContext(t *testing.T) {
	parentCtx := v2.Store(context.Background(), "parentKey", "parentValue")
	childCtx := v2.Store(parentCtx, "childKey", "childValue")

	// The child context must hold its own value.
	assert.Equal(t, "childValue", v2.Load[string](childCtx, "childKey"))
	// The child context must also read the parent's value.
	assert.Equal(t, "parentValue", v2.Load[string](childCtx, "parentKey"))
}
// TestConcurrentReadWrite stores 10000 distinct keys from concurrent goroutines
// into one shared values map, then reads them all back concurrently.
func TestConcurrentReadWrite(t *testing.T) {
	const n = 10000
	// Attach the values map up front (Store with no values only attaches it) so
	// every writer mutates the same map under its lock. Previously each goroutine
	// reassigned a shared ctx variable: a data race that also dropped writes,
	// because early writers each created a private map.
	ctx := v2.Store[int](context.Background(), "init")

	var wg sync.WaitGroup
	for i := 0; i < n; i++ {
		wg.Add(1)
		go func(i int, key string) {
			defer wg.Done()
			// ctx already carries a values map, so Store mutates it in place and
			// returns ctx itself; the result can be ignored.
			v2.Store(ctx, key, i)
		}(i, fmt.Sprintf("concurrentKey_%d", i))
	}
	// Make sure every write has finished before reading.
	wg.Wait()

	for i := 0; i < n; i++ {
		wg.Add(1)
		go func(i int, key string) {
			defer wg.Done()
			if value := v2.Load[int](ctx, key); value != i {
				t.Errorf("键 %s 的值应该是 %d,但得到的是 %v", key, i, value)
			}
		}(i, fmt.Sprintf("concurrentKey_%d", i))
	}
	wg.Wait()
	t.Log("TestConcurrentReadWrite passed")
}
// TestConcurrentReadWhileWrite reads and writes one key of a shared values map
// from many goroutines at once; it passes if nothing races or panics and every
// write lands.
func TestConcurrentReadWhileWrite(t *testing.T) {
	const n = 10000
	key := "concurrentKey"
	// Attach the shared values map before any goroutine starts. Previously ctx
	// started out nil (readers could panic on ctx.Value), was reassigned by
	// writers without synchronization, and each writer built a private map, so
	// nothing was actually read and written concurrently.
	ctx := v2.Store[string](context.Background(), "init")

	var wg sync.WaitGroup
	// Writers: each appends to the same key of the shared map.
	for i := 0; i < n; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			v2.Store(ctx, key, fmt.Sprintf("value%d", i))
		}(i)
	}
	// Readers: no per-read assertion, because the visible value depends on
	// write ordering.
	for i := 0; i < n; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			_ = v2.Load[string](ctx, key)
		}()
	}
	wg.Wait()

	// Every write must have landed in the shared map.
	assert.Len(t, v2.LoadAll[string](ctx, key), n)
}
// TestCompatibility checks that a value written with the old contexthelper
// package can be read back through v2.Load.
func TestCompatibility(t *testing.T) {
	// Write with the old version...
	ctx := contexthelper.Store(context.Background(), "testKey", "testValue")
	// ...and read with the new one.
	got := v2.Load[string](ctx, "testKey")
	assert.Equal(t, "testValue", got,
		"The value retrieved by new version should match the value stored by old version.")
}
// TestCascadingRead checks cascading reads through an explicitly created child
// values map: the child sees the parent's values and its own, while the parent
// does not see what was stored in the child.
func TestCascadingRead(t *testing.T) {
	// Parent context with one stored value.
	parentCtx := v2.Store(context.Background(), "parentKey", "parentValue")
	assert.Equal(t, "parentValue", v2.Load[string](parentCtx, "parentKey"),
		"Value should be stored in parent context")

	// Child context with its own, empty values map that points back at the parent.
	childCtx := context.WithValue(parentCtx, v2.ContextKey, v2.NewValuesMap(parentCtx))
	assert.Equal(t, "parentValue", v2.Load[string](childCtx, "parentKey"),
		"Should retrieve value from parent context")

	// Writes to the child land in the child's map only.
	childCtx = v2.Store(childCtx, "childKey", "childValue")
	assert.Equal(t, "childValue", v2.Load[string](childCtx, "childKey"),
		"Should retrieve value stored in child context")
	assert.Equal(t, "", v2.Load[string](parentCtx, "childKey"),
		"Parent context should not retrieve value stored in child context")
}
// TestCascadingReadFromParentContextV2 exercises cascading reads with a child
// values map created explicitly via v2.NewValuesMap.
func TestCascadingReadFromParentContextV2(t *testing.T) {
	// Parent context holding "parentKey".
	parentCtx := v2.Store(context.Background(), "parentKey", "parentValue")

	// Child context with an empty map of its own; "parentKey" is not stored in it.
	childMap := v2.NewValuesMap(parentCtx)
	childCtx := context.WithValue(parentCtx, v2.ContextKey, childMap)

	// Reads fall back to the parent for keys the child lacks.
	assert.Equal(t, "parentValue", v2.Load[string](childCtx, "parentKey"))

	// Values stored in the child are visible in the child...
	childCtx = v2.Store(childCtx, "childKey", "childValue")
	assert.Equal(t, "childValue", v2.Load[string](childCtx, "childKey"))
	// ...but not in the parent.
	assert.Equal(t, "", v2.Load[string](parentCtx, "childKey"))
}
// TestWriteDoesNotAffectParent checks that a write to a child context with its
// own values map is visible in the child but not in the parent.
func TestWriteDoesNotAffectParent(t *testing.T) {
	// Give the parent its own values map; with a bare context.Background()
	// parent the "not affected" check below could never fail.
	parentCtx := v2.Store(context.Background(), "parentKey", "parentValue")
	childCtx := context.WithValue(parentCtx, v2.ContextKey, v2.NewValuesMap(parentCtx))

	// Store into the child. It already has a map, so Store mutates that map in
	// place and returns childCtx itself.
	childCtx = v2.Store(childCtx, "childKey", "childValue")
	assert.Equal(t, "childValue", v2.Load[string](childCtx, "childKey"))

	// Confirm the parent is not affected.
	assert.Empty(t, v2.Load[string](parentCtx, "childKey"), "Parent context should not have the child's value")
}
gopackage v2
import (
"context"
"sync"
constant "gl.fotechwealth.com.local/backend/trade-lib.git/contexthelper/constant"
)
var (
	// ContextKey is the context key under which this package attaches its
	// *valuesMap. It is the shared key from the constant package; LoadAll also
	// accepts an old-style map[string][]interface{} stored under it (presumably
	// written by the older contexthelper package — confirm).
	ContextKey = constant.SharedContextKey
)
// valuesMap is the mutable bag of values that Store and StoreSingleValue attach
// to a context. Each key maps to the values stored under it, in insertion order.
type valuesMap struct {
	// store holds the values per key; guarded by lock.
	store map[string][]any
	// parentCtx records the context this map was created from.
	parentCtx context.Context
	// lock guards store.
	lock sync.RWMutex
}
// newValuesMap returns an empty values map that remembers parentCtx as the
// context it was created from.
func newValuesMap(parentCtx context.Context) *valuesMap {
	vm := new(valuesMap)
	vm.store = make(map[string][]interface{})
	vm.parentCtx = parentCtx
	return vm
}
// Store appends values, in order, to the list kept under key and returns a
// context that carries the values map.
//
// If ctx already carries a *valuesMap under ContextKey, that map is modified in
// place and ctx itself is returned, so the new values are visible through every
// context sharing the map (including the one passed in). Otherwise a new map,
// with ctx recorded as its parent, is attached to a context derived from ctx.
// Calling Store with no values only attaches the map.
func Store[T any](ctx context.Context, key string, values ...T) context.Context {
	target, attached := ctx.Value(ContextKey).(*valuesMap)
	if !attached {
		target = newValuesMap(ctx)
		ctx = context.WithValue(ctx, ContextKey, target)
	}

	target.lock.Lock()
	defer target.lock.Unlock()
	for i := range values {
		target.store[key] = append(target.store[key], values[i])
	}
	return ctx
}
// StoreSingleValue makes value the only value stored under key and returns a
// context that carries the values map.
//
// If ctx already carries a *valuesMap under ContextKey, that map is modified in
// place and ctx itself is returned; otherwise a new map is attached to a
// context derived from ctx.
func StoreSingleValue[T any](ctx context.Context, key string, value T) context.Context {
	target, attached := ctx.Value(ContextKey).(*valuesMap)
	if !attached {
		target = newValuesMap(ctx)
		ctx = context.WithValue(ctx, ContextKey, target)
	}

	replacement := []interface{}{value}
	target.lock.Lock()
	defer target.lock.Unlock()
	target.store[key] = replacement
	return ctx
}
// Load returns the first value of type T that LoadAll finds for key, or the
// zero value of T when there is none.
func Load[T any](ctx context.Context, key string) (value T) {
	if found := LoadAll[T](ctx, key); len(found) != 0 {
		value = found[0]
	}
	return value
}
// LoadAll returns every value stored under key whose dynamic type is T, in
// insertion order.
//
// Lookup order:
//   - If ctx carries a *valuesMap under ContextKey and that map has at least one
//     value for key, the T-typed values from it are returned (nil if none of
//     them is a T; there is no fallback in that case).
//   - If that map has no values for key, the lookup continues in the context
//     the map was created from (its parentCtx). Without this, wrapping a context
//     that holds an old-style map (or an explicitly nested values map) made the
//     outer keys unreadable.
//   - If ctx instead carries an old-style map[string][]interface{} under
//     ContextKey, the T-typed values for key from that map are returned.
//   - Otherwise nil is returned.
func LoadAll[T any](ctx context.Context, key string) []T {
	switch m := ctx.Value(ContextKey).(type) {
	case *valuesMap:
		// Convert under the read lock, but release it before walking up to the
		// parent so the lock is not held across the recursive lookup.
		m.lock.RLock()
		values := m.store[key]
		var result []T
		if len(values) > 0 {
			result = filterByType[T](values)
		}
		parent := m.parentCtx
		m.lock.RUnlock()

		if len(values) > 0 {
			return result
		}
		if parent != nil {
			return LoadAll[T](parent, key)
		}
		return nil
	case map[string][]interface{}:
		// Context written by the old (non-generic) version of this helper.
		return filterByType[T](m[key])
	default:
		return nil
	}
}

// filterByType returns the elements of values whose dynamic type is T, in
// order, or nil if there are none.
func filterByType[T any](values []interface{}) []T {
	var result []T
	for _, v := range values {
		if value, ok := v.(T); ok {
			result = append(result, value)
		}
	}
	return result
}
gopackage v2_test
import (
"context"
"fmt"
"os"
"sync"
"testing"
uuid "github.com/satori/go.uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
"gl.fotechwealth.com.local/backend/trade-lib.git/contexthelper"
v2 "gl.fotechwealth.com.local/backend/trade-lib.git/contexthelper/v2"
)
// TestSuite groups the testify-suite based tests for the v2 helpers.
type TestSuite struct {
	suite.Suite // provides the s.Equal-style assertion methods
}
// TestRun stores values of different types with v2.Store and reads them back
// through a context that also has plain context.WithValue entries below and
// above the helper's map; then it checks that repeated Stores accumulate.
func (s *TestSuite) TestRun() {
	var (
		requestId      = uuid.NewV4().String()
		requestContext = "POST:/ping"
		uin            = "1"
		pid            = os.Getpid()
	)

	ctx := context.WithValue(context.TODO(), "parent", "1")
	ctx = v2.Store(ctx, "requestId", requestId)
	ctx = v2.Store(ctx, "requestContext", requestContext)
	ctx = v2.Store(ctx, "uin", uin)
	ctx = v2.Store(ctx, "pid", pid)
	// Using the standard context package on top is completely unaffected.
	ctx = context.WithValue(ctx, "son", "2")

	s.Equal(requestId, v2.Load[string](ctx, "requestId"))
	s.Equal(requestContext, v2.Load[string](ctx, "requestContext"))
	s.Equal(uin, v2.Load[string](ctx, "uin"))
	s.Equal(pid, v2.Load[int](ctx, "pid"))
	s.Equal("1", ctx.Value("parent"))
	s.Equal("2", ctx.Value("son"))

	// Store more values under the same key; they accumulate in order.
	for _, extra := range []string{"2", "3"} {
		ctx = v2.Store(ctx, "uin", extra)
	}
	s.Equal([]string{"1", "2", "3"}, v2.LoadAll[string](ctx, "uin"))
}
func TestAll(t *testing.T) {
suite.Run(t, &TestSuite{})
}
// TestStoreMultipleValues checks that one Store call with several values keeps
// all of them under the same key, in order.
func TestStoreMultipleValues(t *testing.T) {
	requestId := uuid.NewV4().String()
	// Store several values under the same key in a single call.
	ctx := v2.Store(context.Background(), "requestId", requestId, "anotherId", "yetAnotherId")
	// Compare the whole slice at once: indexing values[i] after a failed
	// (non-fatal) length check would panic and hide the real failure.
	assert.Equal(t, []string{requestId, "anotherId", "yetAnotherId"}, v2.LoadAll[string](ctx, "requestId"))
}
// TestStoreSingleValue checks that StoreSingleValue replaces, rather than
// appends to, the value stored under a key.
func TestStoreSingleValue(t *testing.T) {
	t.Run("overwrite", func(t *testing.T) {
		ctx := context.Background()
		requestId := uuid.NewV4().String()
		ctx = v2.StoreSingleValue(ctx, "requestId", requestId)
		assert.Equal(t, requestId, v2.Load[string](ctx, "requestId"))

		newRequestId := uuid.NewV4().String()
		ctx = v2.StoreSingleValue(ctx, "requestId", newRequestId)
		// Only the latest value may remain under the key.
		assert.Equal(t, []string{newRequestId}, v2.LoadAll[string](ctx, "requestId"))
	})
}
// TestWriteDoesNotAffectParentContext checks that storing into a context with
// no values map yet leaves the original context untouched.
func TestWriteDoesNotAffectParentContext(t *testing.T) {
	parentCtx := context.Background()
	childCtx := v2.Store(parentCtx, "childKey", "childValue")

	// The child context must see the new value...
	assert.Equal(t, "childValue", v2.Load[string](childCtx, "childKey"))
	// ...while the parent context must not.
	assert.Equal(t, "", v2.Load[string](parentCtx, "childKey"))
}
// TestCascadingReadFromParentContext checks that a context returned by Store
// can read both its own value and one stored earlier on its parent.
//
// NOTE(review): parentCtx already carries a values map, so the second Store
// writes into that same map rather than a separate child layer; the parent
// therefore also sees "childKey".
func TestCascadingReadFromParentContext(t *testing.T) {
	parentCtx := v2.Store(context.Background(), "parentKey", "parentValue")
	childCtx := v2.Store(parentCtx, "childKey", "childValue")

	// The child context must hold its own value.
	assert.Equal(t, "childValue", v2.Load[string](childCtx, "childKey"))
	// The child context must also read the parent's value.
	assert.Equal(t, "parentValue", v2.Load[string](childCtx, "parentKey"))
}
// TestConcurrentReadWrite stores 10000 distinct keys from concurrent goroutines
// into one shared values map, then reads them all back concurrently.
func TestConcurrentReadWrite(t *testing.T) {
	const n = 10000
	// Attach the values map up front (Store with no values only attaches it) so
	// every writer mutates the same map under its lock. Previously each goroutine
	// reassigned a shared ctx variable: a data race that also dropped writes,
	// because early writers each created a private map.
	ctx := v2.Store[int](context.Background(), "init")

	var wg sync.WaitGroup
	for i := 0; i < n; i++ {
		wg.Add(1)
		go func(i int, key string) {
			defer wg.Done()
			// ctx already carries a values map, so Store mutates it in place and
			// returns ctx itself; the result can be ignored.
			v2.Store(ctx, key, i)
		}(i, fmt.Sprintf("concurrentKey_%d", i))
	}
	// Make sure every write has finished before reading.
	wg.Wait()

	for i := 0; i < n; i++ {
		wg.Add(1)
		go func(i int, key string) {
			defer wg.Done()
			if value := v2.Load[int](ctx, key); value != i {
				t.Errorf("键 %s 的值应该是 %d,但得到的是 %v", key, i, value)
			}
		}(i, fmt.Sprintf("concurrentKey_%d", i))
	}
	wg.Wait()
	t.Log("TestConcurrentReadWrite passed")
}
// TestConcurrentReadWhileWrite reads and writes one key of a shared values map
// from many goroutines at once; it passes if nothing races or panics and every
// write lands.
func TestConcurrentReadWhileWrite(t *testing.T) {
	const n = 10000
	key := "concurrentKey"
	// Attach the shared values map before any goroutine starts. Previously ctx
	// started out nil (readers could panic on ctx.Value), was reassigned by
	// writers without synchronization, and each writer built a private map, so
	// nothing was actually read and written concurrently.
	ctx := v2.Store[string](context.Background(), "init")

	var wg sync.WaitGroup
	// Writers: each appends to the same key of the shared map.
	for i := 0; i < n; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			v2.Store(ctx, key, fmt.Sprintf("value%d", i))
		}(i)
	}
	// Readers: no per-read assertion, because the visible value depends on
	// write ordering.
	for i := 0; i < n; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			_ = v2.Load[string](ctx, key)
		}()
	}
	wg.Wait()

	// Every write must have landed in the shared map.
	assert.Len(t, v2.LoadAll[string](ctx, key), n)
}
// TestCompatibility checks that a value written with the old contexthelper
// package can be read back through v2.Load.
func TestCompatibility(t *testing.T) {
	// Write with the old version...
	ctx := contexthelper.Store(context.Background(), "testKey", "testValue")
	// ...and read with the new one.
	got := v2.Load[string](ctx, "testKey")
	assert.Equal(t, "testValue", got,
		"The value retrieved by new version should match the value stored by old version.")
}
gopackage v2_test
import (
"context"
"fmt"
v2 "gl.fotechwealth.com.local/backend/trade-lib.git/contexthelper/v2"
)
// ExampleStore stores a single value and reads it back.
func ExampleStore() {
	ctx := v2.Store(context.Background(), "key1", "value1")
	fmt.Println(v2.Load[string](ctx, "key1"))
	// Output: value1
}
// ExampleStore_multipleValues stores several values under one key in a single call.
func ExampleStore_multipleValues() {
	ctx := v2.Store(context.Background(), "key1", "value1", "value2", "value3")
	fmt.Println(v2.LoadAll[string](ctx, "key1"))
	// Output: [value1 value2 value3]
}
// ExampleStoreSingleValue shows that each call replaces the previous value.
func ExampleStoreSingleValue() {
	ctx := context.Background()
	for _, v := range []string{"value1", "value2", "value3"} {
		ctx = v2.StoreSingleValue(ctx, "key1", v)
	}
	fmt.Println(v2.Load[string](ctx, "key1"))
	// Output: value3
}
// ExampleLoad shows that Load reports only the first value stored under a key.
func ExampleLoad() {
	ctx := context.Background()
	for _, v := range []string{"value1", "value2", "value3"} {
		ctx = v2.Store(ctx, "key1", v)
	}
	fmt.Println(v2.Load[string](ctx, "key1"))
	// Output: value1
}
// ExampleLoadAll shows that LoadAll reports every stored value in insertion order.
func ExampleLoadAll() {
	ctx := context.Background()
	for _, v := range []string{"value1", "value2", "value3"} {
		ctx = v2.Store(ctx, "key1", v)
	}
	all := v2.LoadAll[string](ctx, "key1")
	fmt.Println(all)
	// Output: [value1 value2 value3]
}
gopackage v2_test
import (
"context"
"fmt"
"os"
"sync"
"testing"
uuid "github.com/satori/go.uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
v2 "gl.fotechwealth.com.local/backend/trade-lib.git/contexthelper/v2"
)
// TestSuite groups the testify-suite based tests for the v2 helpers.
type TestSuite struct {
	suite.Suite // provides the s.Equal-style assertion methods
}
// TestRun stores values of different types with v2.Store and reads them back
// through a context that also has plain context.WithValue entries below and
// above the helper's map; then it checks that repeated Stores accumulate.
func (s *TestSuite) TestRun() {
	var (
		requestId      = uuid.NewV4().String()
		requestContext = "POST:/ping"
		uin            = "1"
		pid            = os.Getpid()
	)

	ctx := context.WithValue(context.TODO(), "parent", "1")
	ctx = v2.Store(ctx, "requestId", requestId)
	ctx = v2.Store(ctx, "requestContext", requestContext)
	ctx = v2.Store(ctx, "uin", uin)
	ctx = v2.Store(ctx, "pid", pid)
	// Using the standard context package on top is completely unaffected.
	ctx = context.WithValue(ctx, "son", "2")

	s.Equal(requestId, v2.Load[string](ctx, "requestId"))
	s.Equal(requestContext, v2.Load[string](ctx, "requestContext"))
	s.Equal(uin, v2.Load[string](ctx, "uin"))
	s.Equal(pid, v2.Load[int](ctx, "pid"))
	s.Equal("1", ctx.Value("parent"))
	s.Equal("2", ctx.Value("son"))

	// Store more values under the same key; they accumulate in order.
	for _, extra := range []string{"2", "3"} {
		ctx = v2.Store(ctx, "uin", extra)
	}
	s.Equal([]string{"1", "2", "3"}, v2.LoadAll[string](ctx, "uin"))
}
func TestAll(t *testing.T) {
suite.Run(t, &TestSuite{})
}
// TestStoreMultipleValues checks that one Store call with several values keeps
// all of them under the same key, in order.
func TestStoreMultipleValues(t *testing.T) {
	requestId := uuid.NewV4().String()
	// Store several values under the same key in a single call.
	ctx := v2.Store(context.Background(), "requestId", requestId, "anotherId", "yetAnotherId")
	// Compare the whole slice at once: indexing values[i] after a failed
	// (non-fatal) length check would panic and hide the real failure.
	assert.Equal(t, []string{requestId, "anotherId", "yetAnotherId"}, v2.LoadAll[string](ctx, "requestId"))
}
// TestStoreSingleValue checks that StoreSingleValue replaces, rather than
// appends to, the value stored under a key.
func TestStoreSingleValue(t *testing.T) {
	t.Run("overwrite", func(t *testing.T) {
		ctx := context.Background()
		requestId := uuid.NewV4().String()
		ctx = v2.StoreSingleValue(ctx, "requestId", requestId)
		assert.Equal(t, requestId, v2.Load[string](ctx, "requestId"))

		newRequestId := uuid.NewV4().String()
		ctx = v2.StoreSingleValue(ctx, "requestId", newRequestId)
		// Only the latest value may remain under the key.
		assert.Equal(t, []string{newRequestId}, v2.LoadAll[string](ctx, "requestId"))
	})
}
// TestWriteDoesNotAffectParentContext checks that storing into a context with
// no values map yet leaves the original context untouched.
func TestWriteDoesNotAffectParentContext(t *testing.T) {
	parentCtx := context.Background()
	childCtx := v2.Store(parentCtx, "childKey", "childValue")

	// The child context must see the new value.
	assert.Equal(t, "childValue", v2.Load[string](childCtx, "childKey"))
	// The parent context must not. Load returns the zero value of T ("" for
	// string) when nothing is found; assert.Nil on a string can never pass, so
	// check for emptiness instead.
	assert.Empty(t, v2.Load[string](parentCtx, "childKey"))
}
// TestCascadingReadFromParentContext checks that a context returned by Store
// can read both its own value and one stored earlier on its parent.
//
// NOTE(review): parentCtx already carries a values map, so the second Store
// writes into that same map rather than a separate child layer; the parent
// therefore also sees "childKey".
func TestCascadingReadFromParentContext(t *testing.T) {
	parentCtx := v2.Store(context.Background(), "parentKey", "parentValue")
	childCtx := v2.Store(parentCtx, "childKey", "childValue")

	// The child context must hold its own value.
	assert.Equal(t, "childValue", v2.Load[string](childCtx, "childKey"))
	// The child context must also read the parent's value.
	assert.Equal(t, "parentValue", v2.Load[string](childCtx, "parentKey"))
}
// TestConcurrentReadWrite stores 10000 distinct keys from concurrent goroutines
// into one shared values map, then reads them all back concurrently.
func TestConcurrentReadWrite(t *testing.T) {
	const n = 10000
	// Attach the values map up front (Store with no values only attaches it) so
	// every writer mutates the same map under its lock. Previously each goroutine
	// reassigned a shared ctx variable: a data race that also dropped writes,
	// because early writers each created a private map.
	ctx := v2.Store[int](context.Background(), "init")

	var wg sync.WaitGroup
	for i := 0; i < n; i++ {
		wg.Add(1)
		go func(i int, key string) {
			defer wg.Done()
			// ctx already carries a values map, so Store mutates it in place and
			// returns ctx itself; the result can be ignored.
			v2.Store(ctx, key, i)
		}(i, fmt.Sprintf("concurrentKey_%d", i))
	}
	// Make sure every write has finished before reading.
	wg.Wait()

	for i := 0; i < n; i++ {
		wg.Add(1)
		go func(i int, key string) {
			defer wg.Done()
			if value := v2.Load[int](ctx, key); value != i {
				t.Errorf("键 %s 的值应该是 %d,但得到的是 %v", key, i, value)
			}
		}(i, fmt.Sprintf("concurrentKey_%d", i))
	}
	wg.Wait()
	t.Log("TestConcurrentReadWrite passed")
}
// TestConcurrentReadWhileWrite reads and writes one key of a shared values map
// from many goroutines at once; it passes if nothing races or panics and every
// write lands.
func TestConcurrentReadWhileWrite(t *testing.T) {
	const n = 10000
	key := "concurrentKey"
	// Attach the shared values map before any goroutine starts. Previously ctx
	// started out nil (readers could panic on ctx.Value), was reassigned by
	// writers without synchronization, and each writer built a private map, so
	// nothing was actually read and written concurrently.
	ctx := v2.Store[string](context.Background(), "init")

	var wg sync.WaitGroup
	// Writers: each appends to the same key of the shared map.
	for i := 0; i < n; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			v2.Store(ctx, key, fmt.Sprintf("value%d", i))
		}(i)
	}
	// Readers: no per-read assertion, because the visible value depends on
	// write ordering.
	for i := 0; i < n; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			_ = v2.Load[string](ctx, key)
		}()
	}
	wg.Wait()

	// Every write must have landed in the shared map.
	assert.Len(t, v2.LoadAll[string](ctx, key), n)
}
gopackage v2
import (
"context"
"sync"
)
// k is an unexported key type, so no other package can produce a context key
// equal to the one this package uses.
type k int

// contextKey is the single context key (k(0)) under which this package
// attaches its values map.
var contextKey k
// valuesMap is the mutable bag of values that Store and StoreSingleValue attach
// to a context. Each key maps to the values stored under it, in insertion order.
type valuesMap struct {
	// store holds the values per key; guarded by lock.
	store map[string][]any
	// parentCtx records the context this map was created from.
	parentCtx context.Context
	// lock guards store.
	lock sync.RWMutex
}
// newValuesMap returns an empty values map that remembers parentCtx as the
// context it was created from.
func newValuesMap(parentCtx context.Context) *valuesMap {
	vm := new(valuesMap)
	vm.store = make(map[string][]interface{})
	vm.parentCtx = parentCtx
	return vm
}
// Store appends values, in order, to the list kept under key and returns a
// context that carries the values map.
//
// If ctx already carries a *valuesMap under contextKey, that map is modified in
// place and ctx itself is returned, so the new values are visible through every
// context sharing the map (including the one passed in). Otherwise a new map is
// attached to a context derived from ctx. Calling Store with no values only
// attaches the map.
func Store[T any](ctx context.Context, key string, values ...T) context.Context {
	target, attached := ctx.Value(contextKey).(*valuesMap)
	if !attached {
		target = newValuesMap(ctx)
		ctx = context.WithValue(ctx, contextKey, target)
	}

	target.lock.Lock()
	defer target.lock.Unlock()
	for i := range values {
		target.store[key] = append(target.store[key], values[i])
	}
	return ctx
}
// StoreSingleValue makes value the only value stored under key and returns a
// context that carries the values map.
//
// If ctx already carries a *valuesMap under contextKey, that map is modified in
// place and ctx itself is returned; otherwise a new map is attached to a
// context derived from ctx.
func StoreSingleValue[T any](ctx context.Context, key string, value T) context.Context {
	target, attached := ctx.Value(contextKey).(*valuesMap)
	if !attached {
		target = newValuesMap(ctx)
		ctx = context.WithValue(ctx, contextKey, target)
	}

	replacement := []interface{}{value}
	target.lock.Lock()
	defer target.lock.Unlock()
	target.store[key] = replacement
	return ctx
}
// Load returns the first value of type T that LoadAll finds for key, or the
// zero value of T when there is none.
func Load[T any](ctx context.Context, key string) (value T) {
	if found := LoadAll[T](ctx, key); len(found) != 0 {
		value = found[0]
	}
	return value
}
// LoadAll returns every value stored under key whose dynamic type is T, in
// insertion order. It returns nil when ctx carries no values map, when the key
// is absent, or when none of the stored values is a T.
func LoadAll[T any](ctx context.Context, key string) []T {
	vm, found := ctx.Value(contextKey).(*valuesMap)
	if !found {
		return nil
	}

	vm.lock.RLock()
	defer vm.lock.RUnlock()

	var out []T
	for _, raw := range vm.store[key] {
		if typed, isT := raw.(T); isT {
			out = append(out, typed)
		}
	}
	return out
}
gopackage v2_test
import (
"context"
"fmt"
v2 "gl.fotechwealth.com.local/backend/trade-lib.git/contexthelper/v2"
)
// ExampleStore stores a single value and reads it back.
func ExampleStore() {
	ctx := v2.Store(context.Background(), "key1", "value1")
	fmt.Println(v2.Load[string](ctx, "key1"))
	// Output: value1
}
// ExampleStore_multipleValues stores several values under one key in a single call.
func ExampleStore_multipleValues() {
	ctx := v2.Store(context.Background(), "key1", "value1", "value2", "value3")
	fmt.Println(v2.LoadAll[string](ctx, "key1"))
	// Output: [value1 value2 value3]
}
// ExampleStoreSingleValue shows that each call replaces the previous value.
func ExampleStoreSingleValue() {
	ctx := context.Background()
	for _, v := range []string{"value1", "value2", "value3"} {
		ctx = v2.StoreSingleValue(ctx, "key1", v)
	}
	fmt.Println(v2.Load[string](ctx, "key1"))
	// Output: value3
}
// ExampleLoad shows that Load reports only the first value stored under a key.
func ExampleLoad() {
	ctx := context.Background()
	for _, v := range []string{"value1", "value2", "value3"} {
		ctx = v2.Store(ctx, "key1", v)
	}
	fmt.Println(v2.Load[string](ctx, "key1"))
	// Output: value1
}
// ExampleLoadAll shows that LoadAll reports every stored value in insertion order.
func ExampleLoadAll() {
	ctx := context.Background()
	for _, v := range []string{"value1", "value2", "value3"} {
		ctx = v2.Store(ctx, "key1", v)
	}
	all := v2.LoadAll[string](ctx, "key1")
	fmt.Println(all)
	// Output: [value1 value2 value3]
}
gopackage contexthelper
import (
"context"
)
// k is an unexported key type, so no other package can produce a context key
// equal to the one this package uses.
type k int

// contextKey is the single context key (k(0)) under which this package
// attaches its values map.
var contextKey k
// Store appends value to the list kept under key and returns a context that
// carries the underlying map.
//
// When ctx already carries the map, it is modified in place and ctx itself is
// returned, so every context sharing the map sees the write. Otherwise a new
// map holding just this value is attached to a context derived from ctx. The
// map is not guarded by any lock.
func Store(ctx context.Context, key string, value interface{}) context.Context {
	if raw := ctx.Value(contextKey); raw != nil {
		m := raw.(map[string][]interface{})
		m[key] = append(m[key], value)
		return ctx
	}
	return context.WithValue(ctx, contextKey, map[string][]interface{}{key: {value}})
}
// StoreSingleValue makes value the only value stored under key and returns a
// context that carries the underlying map.
//
// When ctx already carries the map, it is modified in place and ctx itself is
// returned; otherwise a new map holding just this value is attached to a
// context derived from ctx. The map is not guarded by any lock.
func StoreSingleValue(ctx context.Context, key string, value interface{}) context.Context {
	if raw := ctx.Value(contextKey); raw != nil {
		raw.(map[string][]interface{})[key] = []interface{}{value}
		return ctx
	}
	return context.WithValue(ctx, contextKey, map[string][]interface{}{key: {value}})
}
// Load returns the first value stored under key, or nil if there is none.
func Load(ctx context.Context, key string) interface{} {
	if values := LoadAll(ctx, key); len(values) > 0 {
		return values[0]
	}
	return nil
}
// LoadAll returns all values stored under key, in insertion order, or nil if
// ctx carries no map or the key is absent.
//
// The result is a fresh copy. The stored slice is shared by every context that
// carries the same map, so handing it out directly would let a caller overwrite
// stored values through the returned slice.
func LoadAll(ctx context.Context, key string) []interface{} {
	raw := ctx.Value(contextKey)
	if raw == nil {
		return nil
	}
	// append onto a nil slice keeps the nil result for a missing key.
	return append([]interface{}(nil), raw.(map[string][]interface{})[key]...)
}
gopackage contexthelper_test
import (
"context"
"os"
"testing"
"gl.fotechwealth.com.local/backend/trade-lib.git/contexthelper"
uuid "github.com/satori/go.uuid"
"github.com/stretchr/testify/suite"
)
// TestSuite groups the testify-suite based tests for contexthelper.
type TestSuite struct {
	suite.Suite // provides the s.Equal-style assertion methods
}
// TestRun stores values with contexthelper.Store and reads them back through a
// context that also has plain context.WithValue entries below and above the
// helper's map; then it checks that repeated Stores accumulate.
func (s *TestSuite) TestRun() {
	var (
		requestId      = uuid.NewV4().String()
		requestContext = "POST:/ping"
		uin            = "1"
		pid            = os.Getpid()
	)

	ctx := context.WithValue(context.TODO(), "parent", "1")
	ctx = contexthelper.Store(ctx, "requestId", requestId)
	ctx = contexthelper.Store(ctx, "requestContext", requestContext)
	ctx = contexthelper.Store(ctx, "uin", uin)
	ctx = contexthelper.Store(ctx, "pid", pid)
	// Using the standard context package on top is completely unaffected.
	ctx = context.WithValue(ctx, "son", "2")

	s.Equal(requestId, contexthelper.Load(ctx, "requestId"))
	s.Equal(requestContext, contexthelper.Load(ctx, "requestContext"))
	s.Equal(uin, contexthelper.Load(ctx, "uin"))
	s.Equal(pid, contexthelper.Load(ctx, "pid"))
	s.Equal("1", ctx.Value("parent"))
	s.Equal("2", ctx.Value("son"))

	// Store more values under the same key; they accumulate in order.
	for _, extra := range []string{"2", "3"} {
		ctx = contexthelper.Store(ctx, "uin", extra)
	}
	s.Equal([]interface{}{"1", "2", "3"}, contexthelper.LoadAll(ctx, "uin"))
}
func TestAll(t *testing.T) {
suite.Run(t, &TestSuite{})
}
func TestStoreSingleValue(t *testing.T) {
t.Run("", func(t *testing.T) {
ctx := context.Background()
ctx = contexthelper.Store(ctx, "requestId", uuid.NewV4().String())
t.Logf("ctx value:%+v", contexthelper.LoadAll(ctx, "requestId"))
ctx = contexthelper.Store(ctx, "requestId", uuid.NewV4().String())
t.Logf("ctx value:%+v", contexthelper.LoadAll(ctx, "requestId"))
ctx = contexthelper.Store(ctx, "requestId", uuid.NewV4().String())
t.Logf("ctx value:%+v", contexthelper.LoadAll(ctx, "requestId"))
})
t.Run("", func(t *testing.T) {
ctx := context.Background()
ctx = contexthelper.StoreSingleValue(ctx, "requestId", uuid.NewV4().String())
t.Logf("ctx value:%+v", contexthelper.LoadAll(ctx, "requestId"))
ctx = contexthelper.StoreSingleValue(ctx, "requestId", uuid.NewV4().String())
t.Logf("ctx value:%+v", contexthelper.LoadAll(ctx, "requestId"))
ctx = contexthelper.StoreSingleValue(ctx, "requestId", uuid.NewV4().String())
t.Logf("ctx value:%+v", contexthelper.LoadAll(ctx, "requestId"))
})
}
本文作者:JIeJaitt
本文链接:
版权声明：本博客所有文章除特别声明外，均采用 CC BY-NC-SA 许可协议。转载请注明出处！