python中闭包
在函数内部定义的函数和外部定义的函数是一样的,只是他们无法被外部访问:
def g(): print ‘g()...‘ def f(): print ‘f()...‘ return g
将 g 的定义移入函数 f 内部,防止其他代码调用 g:
1 def f(): 2 print ‘f()...‘ 3 def g(): 4 print ‘g()...‘ 5 return g
但是,考察上一小节定义的 calc_sum 函数:
1 def calc_sum(lst): 2 def lazy_sum(): 3 return sum(lst) 4 return lazy_sum
注意: 发现没法把 lazy_sum 移到 calc_sum 的外部,因为它引用了 calc_sum 的参数 lst。
像这种内层函数引用了外层函数的变量(参数也算变量),然后返回内层函数的情况,称为闭包(Closure)。
闭包的特点是返回的函数还引用了外层函数的局部变量,所以,要正确使用闭包,就要确保引用的局部变量在函数返回后不能变。举例如下:
1 # 希望一次返回3个函数,分别计算1x1,2x2,3x3: 2 def count(): 3 fs = [] 4 for i in range(1, 4): 5 def f(): 6 return i*i 7 fs.append(f) 8 return fs 9 10 f1, f2, f3 = count()
你可能认为调用f1(),f2()和f3()结果应该是1,4,9,但实际结果全部都是 9(请自己动手验证)。
原因就是当count()函数返回了3个函数时,这3个函数所引用的变量 i 的值已经变成了3。由于f1、f2、f3并没有被调用,所以,此时他们并未计算 i*i,当 f1 被调用时:
1 >>> f1() 2 9 # 因为f1现在才计算i*i,但现在i的值已经变为3
因此,返回函数不要引用任何循环变量,或者后续会发生变化的变量。
任务
返回闭包不能引用循环变量,请改写count()函数,让它正确返回能计算1x1、2x2、3x3的函数。
源代码的问题:
def count(): fs = [] for i in range(1, 4): def f(): return i*i fs.append(f) return fs f1, f2, f3 = count() print f1(), f2(), f3()#调用返回函数
注意到 fs.append(f)中传入的为一个函数f,即在list中对应着三个函数地址(试试print f1,f2,f3 会显示为三个地址),调用f1,f2,f3 时则都对应着调用f函数,此时for循环已经结束,最终赋值为3,即return 3*3。
尝试改写函数: fs.append(f()),此时传入list为三个值1,4,9.即[1,4,9],当执行f1,f2,f3=count()语句时,为依次取出list表中的值,即1,4,9
def count(): fs = [] for i in range(1, 4): def f(): return i*i fs.append(f()) return fs f1, f2, f3 = count()#这里f1,f2,f3都是list中的值,1,4,9 print f1,f2,f3#打印值
上面这种改写方式与我们的题目要求,改写count()函数,让它能正确返回函数,并且最后调用返回函数,打印出1,4,9不符合。我们采用下面这样的方式改写:
1 def f(j): 2 def g(): 3 return j*j 4 return g
它可以正确地返回一个闭包g,g所引用的变量j不是循环变量,因此将正常执行。
在count函数的循环内部,如果借助f函数,就可以避免引用循环变量i。
def count(): fs = [] for i in range(1, 4): def f(j):#借助一个新的函数f,定义了非循环变量j def g(): return j*j return g r = f(i) fs.append(r) return fs f1, f2, f3 = count() print f1(), f2(), f3()
第二行是定义了一个list,名称为fs,第三行生成1,2,3序列,并且赋值给i。
从第八行看起,当i=1时,将i传入f()之中,f中的j就被赋值了j=i=1,
f()中又有一个函数g(),将参数j传入,在g()中返回j*j,也就是1*1,然后f()中又返回g()(注意这里返回的是一个函数,并不是函数值),也就是g()其中g中的j已经被传入外层函数的值j=1,将g()这个值赋值给变量r,再将r添加到list中 ,在count()中将list返回,此时list为 fs=[g()],f1=fs=[g()],其中g()中j*j变为1*1,打印的时候f1()运行的是
def g():
return 1*1
然后在当i=2和i=3时,继续上面的循环。