当前位置: 首页 > 后端技术 > Python

Go语言中WaitGroup设计详解

时间:2023-03-25 19:54:47 Python

Go语言提供的协程goroutine让我们可以轻松编写多线程程序,但是如何有效地控制这些并发执行的goroutine是我们需要讨论的问题。正如小菜刀在《Golang并发控制简述》中提到的,在Go标准库提供的同步原语中,锁和原子操作重点控制goroutine之间的数据安全,而WaitGroup、channel、Context控制它们的并发行为。锁、原子操作、通道的实现原理都已经详细分析过了。因此,在本文中,我们将重点介绍WaitGroup。认识WaitGroupWaitGroup是sync包下的内容,用来控制协程之间的同步。WaitGroup的使用场景和名字的意思一样。当我们需要等待一组协程执行完才能进行后续处理时,可以考虑使用它。funcmain(){varwgsync.WaitGroupwg.Add(2)//workernumber2gofunc(){//worker1做某事fmt.Println("goroutine1done!")wg.Done()}()gofunc(){//worker2做某事fmt.Println("goroutine2done!")wg.Done()}()wg.Wait()//等待所有服务员完成fmt.Println("allworkdone!")}//outputgoroutine2完成了!goroutine1完成!所有工作完成!可以看到WaitGroup的使用非常简单,它提供了三个方法。虽然goroutine之间没有父子关系,但是为了方便理解,本文将调用Wait函数的goroutine称为主goroutine,调用Done函数的goroutine称为子goroutine。func(wg*WaitGroup)Add(deltaint)//增加WaitGroup中的子协程计数值func(wg*WaitGroup)Done()//当子协程任务完成时,将计数值减1func(wg*WaitGroup)Wait()//阻塞goroutine调用这个方法,直到计数值为0那么它是如何实现的呢?在源码src/sync/waitgroup.go中可以看到其核心源码不到100行,非常简洁,值得学习。前置知识代码少,但不代表实现简单,容易理解。相反,如果读者没有以下前置知识,真正理解WaitGroup的实现会比较吃力。在解析源码之前,先来梳理一下知识点(如果你已经掌握了,可以直接跳到下面的源码解析部分)。信号量在学习操作系统的时候,我们知道信号量是一种保护共享资源的机制,用来解决多线程同步问题。信号量s是一个具有非负整数值的全局变量,只能由两个特殊操作P和V处理。P(s):如果s非零,则P将s减1并立即返回.如果s为零,则挂起线程直到s变为非零,并等待另一个执行V(s)操作的线程唤醒该线程。唤醒后,P操作将s减1,并将控制权返回给调用者。V(s):V运算将s加1。如果任何线程在等待s变为非零的P操作中被阻塞,则V操作唤醒这些线程之一,然后将s减1并完成其P操作。在Go底层的信号量函数中,runtime_Semacquire(s*uint32)函数会阻塞goroutine,直到信号量s的值大于0,然后原子地递减这个值,即P操作。runtime_Semrelease(s*uint32,lifobool,skipframesint)函数原子地增加信号量的值,然后通知被runtime_Semacquire阻塞的goroutine,即V操作。这两个信号量函数不仅仅在WaitGroup中使用。在《Go精妙的互斥锁设计》一文中,我们发现Go在设计互斥量时,信号量的参与必不可少。内存对齐对于下面的结构,你能回答它占用多少内存吗?typeInsstruct{xbool//1byteyint32//4byteszbyte//1byte}funcmain(){ins:=Ins{}fmt.Printf("inssize:%d,align:%d\n",unsafe.Sizeof(ins),unsafe.Alignof(ins))}//outputinssize:12,align:4根据结构中字段的大小,ins对象占用的内存应该是1+4+1=6个字节,但实际上是12个字节,这是内存对齐造成的。从《CPU缓存体系对Go程序的影响》文章中我们知道CPU的内存读取不是一个字节一个字节的读取,而是一块一块的读取。因此,在类型的值在内存中对齐的情况下,计算机可以高效地加载或写入它们。聚合类型(结构或数组)的内存大小可能大于其元素的总和。编译器将添加未使用的内存地址来填充内存间隙,以确保连续的成员或元素与结构或数组的起始地址对齐。因此,我们在设计结构体时,当结构体成员的类型不同时,在相邻的位置定义相同类型的成员可以节省更多的内存空间。原子操作CASCAS是一种原子操作,可用于在多线程编程中实现不间断的数据交换操作,避免多个线程同时改写某个数据时执行顺序的不确定性和中断的不可预测性数据不一致问题。该操作将内存中的值与指定数据进行比较,当值相同时,将内存中的数据替换为新值。关于Go中原子操作的底层实现,小菜岛在文章《同步原语的基石》中有详细的介绍。移位操作>>和<<在前面关于锁《Go精妙的互斥锁设计》和《Go更细粒度的读写锁设计中》的文章中,我们可以看到很多位操作。灵活的位运算可以将一个普通的数字变成丰富的含义。这里只介绍下面会用到的移位操作。对于左移运算<<,将所有数以二进制形式左移相应位数,高位舍弃,低位补零。在不溢出的前提下,左移一位相当于乘以2的1次方,左移n位相当于乘以2的n次方。右移操作>>,将所有数以二进制形式右移相应位数,低位移出,高位补符号位。右移一位相当于除以2,右移n位相当于除以2的n次方。这是取商,不需要余数。移位操作也可以有非常巧妙的操作,我们将在本文后面看到移位操作的高级应用。Go中的unsafe.Pointer指针和uintptr指针可以分为三类:1、普通类型指针*T,如*int;2.unsafe.Pointer指针;3.uintptr。*T:普通指针类型,用于传递对象地址,不能进行指针计算。unsafe.Pointer指针:通用指针,任何普通类型的指针*T都可以转换为unsafe.Pointer指针,unsafe.Pointer类型的指针也可以转换回普通指针,可以是不同于原来的指针类型*T是一样的。但它不能进行指针计算,也不能读取内存中的值(必须转换为特定类型的普通指针)。uintptr:准确的说,uintptr不是一个指针,它是一个无符号整数,大小不清楚。unsafe.Pointer类型可以与uinptr相互转换。由于uinptr类型存储的是指针指向的地址的值,所以可以通过这个值进行指针操作。GC时,uintptr不会作为指针使用,uintptr类型的对象会被回收。unsafe.Pointer是一个桥梁,可以将任何类型的普通指针相互转换,也可以将任何类型的指针转??换为uintptr进行指针运算。但是unsafe.Pointer和任意类型指针的转换,允许我们向内存中写入任意值,这会破坏Go原有的类型系统,而且由于并不是所有的值都是合法的内存地址,从uintptr到unsafe.Pointer的转换也破坏了类型系统。因此,既然Go将包定义为不安全的,就不能随意使用。源码分析本文基于Go源码1.15.7版本。sync.WaitGroup结构体定义如下,其中包含一个noCopy辅助字段和一个复合含义的state1字段。typeWaitGroupstruct{noCopynoCopy//64位值:高32位是计数器,低32位是服务员计数。//64位原子操作需要64位对齐,但32位//编译器不确保这一点。所以我们分配12个字节,然后使用//其中对齐的8个字节作为状态,其他4个字节作为存储//用于sema。state1[3]uint32}//state返回指向存储在wg.state1.func(wg*WaitGroup)state()(statep*uint64,semap*uint32)中的state和sema字段的指针{//64位编译器地址可以能被8整除,所以可以判断是否是64位对齐ifuintptr(unsafe.Pointer(&wg.state1))%8==0{return(*uint64)(unsafe.Pointer(&wg.state1)),&wg.state1[2]}else{return(*uint64)(unsafe.Pointer(&wg.state1[1])),&wg.state1[0]}}其中noCopy字段是一个空结构体,它做不占用内存,编译器不会用字节填充。主要用于通过govet工具进行静态编译检查,防止开发者在使用过程中复制WaitGroup,造成安全隐患。这部分可以看《no copy机制》了解详情。state1字段是一个长度为3的uint32数组。用来表示三部分:1.Add()设置的子goroutine的计数值计数器;2、Wait()阻塞的等待者数量;3.信号量semap。由于后续操作是对uint64类型的statep进行的,而64位整数的原子操作需要64位对齐,所以32位编译器无法保证这一点。因此,在64位和32位环境下,state1字段的组成和含义是不同的。需要注意的是,当我们初始化一个WaitGroup对象时,它的counter值,waiter值,semap值都是0。Add函数Add()函数的输入参数是一个整数,可以是正数也可以是负数,是计数器值的变化。如果计数器的值变为0,则所有阻塞在Wait()函数中的等待者都会被唤醒;如果计数器值为负,则会引起恐慌。我们去掉racedetection部分的代码,Add()函数源码如下semapstatep信号量值,semap:=wg.state()state:=atomic.AddUint64(statep,uint64(delta)<<32)v:=int32(state>>32)w:=uint32(state)ifv<0{panic("sync:negativeWaitGroupcounter")}ifw!=0&&delta>0&&v==int32(delta){panic("sync:WaitGroupmisuse:AddcalledconcurrentlywithWait")}ifv>0||w==0{return}if*statep!=state{panic("sync:WaitGroupmisuse:AddcalledconcurrentlywithWait")}//如果执行到这里,一定是counter=0,waiter>0//能执行到这里,肯定是执行了Add(-x)的goroutine//它的执行意味着所有的子goroutine都完成了任务//因此,我们需要将所有的复合状态重置为0并释放信号量*statepofnumberofwaiters=0for;w!=0;w--{//释放信号量,一旦执行就会唤醒一个阻塞的waiterruntime_Semrelease(semap,false,0)}}代码非常简洁,接下来分析关键部分。state:=atomic.AddUint64(statep,uint64(delta)<<32)//添加计数器值deltav:=int32(state>>32)//获取计数器值w:=uint32(state)//获取服务员值此时的statep是一个uint64值。如果statep中包含的counter数为2,waiter为1,inputdelta为1,那么这三行代码的逻辑过程如下图所示。得到当前柜台编号v和服务员编号w后,会分几种情况判断它们的值。//情况1:这是一个非常低级的错误,计数器值不能为负ifv<0{panic("sync:negativeWaitGroupcounter")}//情况2:误用导致恐慌//因为wg实际上可以bereused是用的,但是下次复用的基础是所有状态都需要重置为0ifw!=0&&delta>0&&v==int32(delta){panic("sync:WaitGroupmisuse:AddcalledconcurrentlywithWait")}//情况三:本次Add操作只负责增加计数器值,直接返回即可。//如果此时counter值大于0,则唤醒操作留给后续的Addcaller使用(执行Add(negativeint))//如果waiter值为0,表示没有阻塞waiter此时ifv>0||w==0{return}//情况4:误用导致panicif*statep!=state{panic("sync:WaitGroupmisuse:AddcalledconcurrentlywithWait")}关于误用重用导致panic的情况,if没有示例错误代码其实更难解释。好消息是Go源代码中有错误使用的示例,这些示例位于src/sync/waitgroup_test.go文件中。想深入了解的读者可以看下面三个测试函数中的例子。funcTestWaitGroupMisuse(t*testing.T)funcTestWaitGroupMisuse2(t*testing.T)funcTestWaitGroupMisuse3(t*testing.T)Done函数Done()函数比较简单,就是调用Add(-1)。在实际使用中,当子goroutine任务完成后,应该调用Done()函数。func(wg*WaitGroup)Done(){wg.Add(-1)}Waitfunction如果WaitGroup中的计数器值大于0,那么执行Wait()函数的主goroutine会将waiter值加1并阻塞等待值为0,则后续代码可以继续执行。我们将racedetection部分的代码去掉,Wait()函数的源码如下LoadUint64(statep)//原子读取复合状态statepv:=int32(state>>32)//获取计数器值w:=uint32(state)//获取waiter值//如果此时v==0次,证明没有pendingtask子goroutine可以直接退出。ifv==0{return}//如果在执行CAS原子操作和读取复合状态之间没有其他goroutine改变复合状态//然后将waiter值增加1,否则:进入下一轮循环并开始再次读取复合状态ifatomic.CompareAndSwapUint64(statep,state,state+1){//成功累加waiter值后//等待Add函数调用runtime_Semrelease唤醒自己runtime_Semacquire(semap)//重用triggerpanic//在当前goroutine被唤醒的时候,因为自己唤醒的goroutine调用了Add方法//复位操作已经通过*statep=0语句完成//此时复合状态位不为0,因为Waiter还没有执行完Wait,WaitGroup已经被重用了if*statep!=0{panic("sync:WaitGroupisreusedbeforepreviousWaithasreturned")}return}}}总结要了解WaitGroup的源码实现,我们需要一些前置知识,例如信号量、内存对齐、原子操作、移位操作和指针转换。但其实WaitGroup的实现思路还是挺简单的。通过结构域state1维护了两个计数器和一个信号量。计数器是通过Add()添加的子goroutine的计数值计数器,以及通过Wait()阻塞的waiter的计数,信号量用于阻塞和唤醒Waiter。执行Add(positiven)时,counter+=n,表示添加了n个子goroutine执行任务。每个子goroutine完成任务后,需要调用Done()函数将counter的值减1,当最后一个子goroutine完成后,counter的值为0,此时需要唤醒阻塞在Wait()调用中的Waiter。但是,在使用WaitGroup时,有几点需要注意,通过Add()函数添加的计数器的个数必须与通过Done()减去的值保持一致。如果前者很大,阻塞在Wait()调用处的goroutine将永远不会被唤醒;如果后者很大,就会引发恐慌。Add()的增量函数应该首先执行。不要复制WaitGroup对象。如果要重用WaitGroup,则必须在所有先前的Wait()调用返回后进行新的Add()调用。