PySpark 02 - Closure and Persistence

1 Closure

A task’s closure is those variables and methods which must be visible for the executor to perform its computations on the RDD.

  • Functions that run on RDDs at executors
  • Any global variables used by those functions

The variables within the closure sent to each executor are copies.

This closure is serialized and sent to each executor from the driver when an action is invoked.

For example:

counter = 0
rdd = sc.parallelize(range(10))

def increment_counter(x):
    global counter
    counter += x

print(rdd.collect())
rdd.foreach(increment_counter)

print(counter)
print(rdd.sum())

The output will be:

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
0
45

The driver's counter stays at 0: each task incremented its own copy of counter from the serialized closure, and those updates were never sent back to the driver. rdd.sum() still returns 45 because it aggregates on the executors and ships the result back through Spark itself.

Accumulators

Accumulators are variables that are only “added” to through an associative and commutative operation. An accumulator is created from an initial value v by calling SparkContext.accumulator(v). Tasks running on the cluster can then add to it using the add method or the += operator. Only the driver program can read the accumulator’s value, via its value attribute.

rdd = sc.parallelize(range(10))
accum = sc.accumulator(0)

def g(x):
    global accum
    accum += x

rdd.foreach(g)

print(accum.value)

The output will be:

45

Workers are never allowed to read the value of accum, so accum cannot appear on the right-hand side of +=. Moreover, only += (or the add method) is supported; using plain + or = on an accumulator raises an error.

Accumulators are rarely the best tool, since most jobs that use them can be expressed with better-suited methods. For example, the code above can be replaced with .sum() or .reduce(), as sketched below.
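A minimal sketch of those equivalents, reusing the rdd defined above (accum2 is a name introduced here for illustration):

# Equivalent results without an accumulator:
print(rdd.sum())                        # 45
print(rdd.reduce(lambda x, y: x + y))   # 45

# add() does the same as += and needs no global statement:
accum2 = sc.accumulator(0)
rdd.foreach(lambda x: accum2.add(x))
print(accum2.value)                     # 45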

2 Example: Computing Pi using Monte Carlo simulation

Draw points uniformly from the square [-1, 1] × [-1, 1]: the fraction that lands inside the unit circle converges to π/4, so π ≈ 4 · count / n.

# From the official Spark examples.
from random import random
from operator import add

partitions = 100
n = 100000 * partitions

def f(_: int) -> float:
    x = random() * 2 - 1
    y = random() * 2 - 1
    return 1 if x ** 2 + y ** 2 <= 1 else 0

count = spark.sparkContext.parallelize(range(1, n + 1), partitions).map(f).reduce(add)
print("Pi is roughly %f" % (4.0 * count / n))

glom(): turn each partition into a list, producing an RDD of lists (one list per partition):

# Example: glom
from random import random

a = sc.parallelize(range(0, 100), 10)
print(a.collect())
print(a.glom().collect())
print(a.map(lambda _: random()).glom().collect())

Output:

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99]
[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], [20, 21, 22, 23, 24, 25, 26, 27, 28, 29], [30, 31, 32, 33, 34, 35, 36, 37, 38, 39], [40, 41, 42, 43, 44, 45, 46, 47, 48, 49], [50, 51, 52, 53, 54, 55, 56, 57, 58, 59], [60, 61, 62, 63, 64, 65, 66, 67, 68, 69], [70, 71, 72, 73, 74, 75, 76, 77, 78, 79], [80, 81, 82, 83, 84, 85, 86, 87, 88, 89], [90, 91, 92, 93, 94, 95, 96, 97, 98, 99]]
[[0.028293184296565688, 0.9415562209627079, 0.5093280586882438, 0.8218821565762561, 0.44762257639387604, 0.4226765740328565, 0.5854233995480139, 0.627818147276162, 0.7019893766838969, 0.6376178566256638], [0.028293184296565688, 0.9415562209627079, 0.5093280586882438, 0.8218821565762561, 0.44762257639387604, 0.4226765740328565, 0.5854233995480139, 0.627818147276162, 0.7019893766838969, 0.6376178566256638], [0.028293184296565688, 0.9415562209627079, 0.5093280586882438, 0.8218821565762561, 0.44762257639387604, 0.4226765740328565, 0.5854233995480139, 0.627818147276162, 0.7019893766838969, 0.6376178566256638], [0.028293184296565688, 0.9415562209627079, 0.5093280586882438, 0.8218821565762561, 0.44762257639387604, 0.4226765740328565, 0.5854233995480139, 0.627818147276162, 0.7019893766838969, 0.6376178566256638], [0.028293184296565688, 0.9415562209627079, 0.5093280586882438, 0.8218821565762561, 0.44762257639387604, 0.4226765740328565, 0.5854233995480139, 0.627818147276162, 0.7019893766838969, 0.6376178566256638], [0.028293184296565688, 0.9415562209627079, 0.5093280586882438, 0.8218821565762561, 0.44762257639387604, 0.4226765740328565, 0.5854233995480139, 0.627818147276162, 0.7019893766838969, 0.6376178566256638], [0.028293184296565688, 0.9415562209627079, 0.5093280586882438, 0.8218821565762561, 0.44762257639387604, 0.4226765740328565, 0.5854233995480139, 0.627818147276162, 0.7019893766838969, 0.6376178566256638], [0.028293184296565688, 0.9415562209627079, 0.5093280586882438, 0.8218821565762561, 0.44762257639387604, 0.4226765740328565, 0.5854233995480139, 0.627818147276162, 0.7019893766838969, 0.6376178566256638], [0.028293184296565688, 0.9415562209627079, 0.5093280586882438, 0.8218821565762561, 0.44762257639387604, 0.4226765740328565, 0.5854233995480139, 0.627818147276162, 0.7019893766838969, 0.6376178566256638], [0.028293184296565688, 0.9415562209627079, 0.5093280586882438, 0.8218821565762561, 0.44762257639387604, 0.4226765740328565, 0.5854233995480139, 0.627818147276162, 0.7019893766838969, 0.6376178566256638]]

The output shows that the random numbers in different partitions are identical: every task starts from the same random-generator state, so all partitions produce the same stream. In this case, adding partitions and samples cannot improve the accuracy of the Monte Carlo simulation.

How to fix it:

Introduce mapPartitions() and mapPartitionsWithIndex(), which apply a function to the iterator over each partition (the latter also passing the partition's index):

# Example: mapPartitions and mapPartitionsWithIndex
a = sc.parallelize(range(0, 20), 4)
print(a.glom().collect())

# Compute the prefix sums within each partition
def f(it):
    s = 0
    for i in it:
        s += i
        yield s

print(a.mapPartitions(f).collect())

# Same, but start each partition's running sum at its index
def f(index, it):
    s = index
    for i in it:
        s += i
        yield s

print(a.mapPartitionsWithIndex(f).collect())
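For reference, these results are deterministic. With the partitioning above, and yielding inside the loops as written, the three prints should give:

[[0, 1, 2, 3, 4], [5, 6, 7, 8, 9], [10, 11, 12, 13, 14], [15, 16, 17, 18, 19]]
[0, 1, 3, 6, 10, 5, 11, 18, 26, 35, 10, 21, 33, 46, 60, 15, 31, 48, 66, 85]
[0, 1, 3, 6, 10, 6, 12, 19, 27, 36, 12, 23, 35, 48, 62, 18, 34, 51, 69, 88]

Each partition restarts its running sum, and the indexed version offsets each partition's sums by its partition index.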

Therefore, with mapPartitionsWithIndex, each partition can seed the generator differently:

# Correct version for computing Pi
from random import random, seed
from time import time

partitions = 100
n = 100000 * partitions

s = time()

def f(index, it):
    seed(index + s)  # distinct seed per partition, varying across runs
    for _ in it:
        x = random() * 2 - 1
        y = random() * 2 - 1
        yield 1 if x ** 2 + y ** 2 <= 1 else 0

count = sc.parallelize(range(1, n + 1), partitions).mapPartitionsWithIndex(f).sum()

print("Pi is roughly", 4.0 * count / n)