2018-08-15
RDD,全称 Resilient Distributed Dataset,是 Spark 中最基础的分布式数据抽象。虽然现在很多项目更常用 DataFrame 和 Dataset,但理解 RDD 仍然很有价值。
这篇文章主要用 Python 代码演示 PySpark 中常见 RDD 的用法,包括 RDD 创建、常用转换算子、行动算子、键值 RDD、缓存、分区和一些实战注意点。
可以把 RDD 理解成一个分布式集合。
普通 Python 列表只存在一台机器的内存里,而 RDD 的数据会被切成多个分区,分布在不同 Executor 上并行计算。
例如:
RDD
├── partition-0
├── partition-1
├── partition-2
└── partition-3
RDD 有几个很重要的特点:
所谓惰性执行,就是你写了 map、filter、flatMap 这类操作,Spark 不会马上运行。只有遇到 collect、count、take、saveAsTextFile 这类行动算子时,任务才会真正提交。
在 PySpark 中,一般可以通过 SparkSession 拿到 SparkContext:
from pyspark.sql import SparkSession
spark = SparkSession.builder \
.appName("rdd-demo") \
.getOrCreate()
sc = spark.sparkContext
后面的 RDD 示例都基于这个 sc。
最简单的方式是从 Python 集合创建:
nums = sc.parallelize([1, 2, 3, 4, 5])
也可以指定分区数:
nums = sc.parallelize([1, 2, 3, 4, 5], 2)
print(nums.getNumPartitions())
读取文本文件:
lines = sc.textFile("hdfs:///data/access.log")
如果是本地测试,也可以读取本地文件:
lines = sc.textFile("file:///tmp/access.log")
textFile 读取后得到的是一个 RDD[str],每个元素是一行文本。
map 会对 RDD 中的每个元素执行一个函数,并返回新的 RDD。
nums = sc.parallelize([1, 2, 3, 4])
result = nums.map(lambda x: x * 10)
print(result.collect())
输出:
[10, 20, 30, 40]
map 很适合做字段转换、类型转换、简单计算。
比如从日志中取 IP:
lines = sc.parallelize([
"10.0.0.1 /home 200",
"10.0.0.2 /detail 404",
"10.0.0.3 /pay 500",
])
ips = lines.map(lambda line: line.split(" ")[0])
print(ips.collect())
flatMap 和 map 很像,但它可以把一个元素拆成多个元素。
单词拆分是最经典的例子:
lines = sc.parallelize([
"hello spark",
"hello rdd",
])
words = lines.flatMap(lambda line: line.split(" "))
print(words.collect())
输出:
['hello', 'spark', 'hello', 'rdd']
如果用 map:
result = lines.map(lambda line: line.split(" "))
print(result.collect())
输出会是:
[['hello', 'spark'], ['hello', 'rdd']]
map 保留嵌套结构,flatMap 会把嵌套结构打平。
filter 用来保留满足条件的数据。
nums = sc.parallelize([1, 2, 3, 4, 5, 6])
even = nums.filter(lambda x: x % 2 == 0)
print(even.collect())
输出:
[2, 4, 6]
过滤日志也很常见:
error_logs = lines.filter(lambda line: "ERROR" in line)
filter 不改变元素结构,只改变数据量。
distinct 用于去重:
nums = sc.parallelize([1, 1, 2, 2, 3, 3])
unique = nums.distinct()
print(unique.collect())
输出:
[1, 2, 3]
需要注意,distinct 通常会触发 Shuffle。数据量很大时,不要随手乱用。
RDD 支持一些集合类操作。
rdd1 = sc.parallelize([1, 2, 3])
rdd2 = sc.parallelize([3, 4, 5])
合并:
print(rdd1.union(rdd2).collect())
输出:
[1, 2, 3, 3, 4, 5]
union 不会自动去重。
取交集:
print(rdd1.intersection(rdd2).collect())
取差集:
print(rdd1.subtract(rdd2).collect())
这些操作看起来简单,但大多涉及 Shuffle,数据量大时要关注性能。
sample 可以从 RDD 中抽取一部分数据。
nums = sc.parallelize(range(100))
sampled = nums.sample(
withReplacement=False,
fraction=0.1
)
print(sampled.take(10))
参数说明:
withReplacement:是否放回抽样fraction:抽样比例抽样很适合做数据预览、调试、估算数据分布。
groupBy 可以根据函数返回值分组。
nums = sc.parallelize([1, 2, 3, 4, 5, 6])
grouped = nums.groupBy(lambda x: x % 2)
result = grouped.mapValues(list).collect()
print(result)
输出类似:
[(0, [2, 4, 6]), (1, [1, 3, 5])]
groupBy 用起来方便,但它会把相同分组的数据拉到一起。如果某个分组特别大,容易造成内存压力。
如果只是做求和、计数,通常更推荐使用 reduceByKey。
常见行动算子有:
rdd.collect()
rdd.count()
rdd.first()
rdd.take(10)
rdd.reduce(lambda a, b: a + b)
rdd.foreach(lambda x: print(x))
rdd.saveAsTextFile("hdfs:///output/path")
collect 会把所有数据拉回 Driver:
data = rdd.collect()
数据量小时它很好用。数据量大时,它可能把 Driver 内存打爆。
调试时更推荐:
print(rdd.take(10))
或者:
print(rdd.sample(False, 0.01).take(20))
reduce 可以把 RDD 中的元素合并成一个结果。
nums = sc.parallelize([1, 2, 3, 4])
total = nums.reduce(lambda a, b: a + b)
print(total)
输出:
10
reduce 的函数最好满足结合律,否则分布式执行时结果可能不稳定。
适合的场景包括:
foreach 会在 Executor 端执行函数。
rdd.foreach(lambda x: print(x))
初学时很容易写出这种代码:
count = 0
def add_count(x):
global count
count += 1
rdd.foreach(add_count)
print(count)
这通常不会得到你想要的结果。因为 Executor 里执行的是 Driver 变量的副本,不会直接修改 Driver 端的 count。
要统计数量,用:
print(rdd.count())
要做辅助统计,可以使用累加器。
在 PySpark 里,键值对通常用 tuple 表示。
pairs = sc.parallelize([
("spark", 1),
("hadoop", 1),
("spark", 1),
])
这种结构可以使用很多按 key 处理的算子,比如 reduceByKey、groupByKey、join、sortByKey。
单词统计是 reduceByKey 的经典例子:
lines = sc.parallelize([
"hello spark",
"hello rdd",
"spark rdd",
])
word_count = (
lines
.flatMap(lambda line: line.split(" "))
.map(lambda word: (word, 1))
.reduceByKey(lambda a, b: a + b)
)
print(word_count.collect())
输出类似:
[('hello', 2), ('spark', 2), ('rdd', 2)]
reduceByKey 会先在分区内做本地聚合,再进行 Shuffle,通常比 groupByKey 更适合聚合统计。
groupByKey 会把相同 key 的所有 value 放到一起。
pairs = sc.parallelize([
("spark", 1),
("spark", 1),
("hadoop", 1),
])
grouped = pairs.groupByKey()
print(grouped.mapValues(list).collect())
输出:
[('spark', [1, 1]), ('hadoop', [1])]
如果只是为了求和,不推荐这样写:
result = pairs.groupByKey().mapValues(sum)
更推荐:
result = pairs.reduceByKey(lambda a, b: a + b)
groupByKey 会把所有 value 都拉到同一个 key 下,大 key 很容易造成内存压力。
mapValues 会保留 key,只转换 value。
scores = sc.parallelize([
("Tom", 80),
("Jerry", 90),
])
result = scores.mapValues(lambda score: score + 10)
print(result.collect())
输出:
[('Tom', 90), ('Jerry', 100)]
这种写法比手动拆 tuple 更清晰。
data = sc.parallelize([
("user1", "redis,spark,hadoop"),
("user2", "spark,kafka"),
])
result = data.flatMapValues(lambda tags: tags.split(","))
print(result.collect())
输出:
[('user1', 'redis'), ('user1', 'spark'), ('user1', 'hadoop'), ('user2', 'spark'), ('user2', 'kafka')]
它适合标签展开、明细展开这类场景。
按 key 排序:
data = sc.parallelize([
(3, "c"),
(1, "a"),
(2, "b"),
])
sorted_data = data.sortByKey()
print(sorted_data.collect())
降序:
sorted_desc = data.sortByKey(ascending=False)
如果不是键值 RDD,可以用 sortBy:
users = sc.parallelize([
("Tom", 20),
("Jerry", 18),
("Alice", 25),
])
sorted_users = users.sortBy(lambda x: x[1])
print(sorted_users.collect())
排序通常会触发 Shuffle,大数据量排序要谨慎。
RDD 可以按 key 做连接。
users = sc.parallelize([
(1, "Tom"),
(2, "Jerry"),
])
orders = sc.parallelize([
(1, "order-a"),
(1, "order-b"),
(2, "order-c"),
])
joined = users.join(orders)
print(joined.collect())
输出类似:
[(1, ('Tom', 'order-a')), (1, ('Tom', 'order-b')), (2, ('Jerry', 'order-c'))]
常见连接还有:
users.leftOuterJoin(orders)
users.rightOuterJoin(orders)
users.fullOuterJoin(orders)
两个大 RDD 做 join 时,通常会产生 Shuffle。要特别注意 key 倾斜和分区数量。
如果其中一个数据集很小,可以考虑广播变量。
cogroup 可以把多个 RDD 中相同 key 的 value 聚到一起。
rdd1 = sc.parallelize([
("a", 1),
("b", 2),
])
rdd2 = sc.parallelize([
("a", 10),
("a", 20),
("c", 30),
])
result = rdd1.cogroup(rdd2)
print(result.mapValues(lambda x: (list(x[0]), list(x[1]))).collect())
输出类似:
[('a', ([1], [10, 20])), ('b', ([2], [])), ('c', ([], [30]))]
它适合多个来源按同一个 key 汇总处理的场景。
如果输入 value 类型和输出 value 类型不同,可以使用 aggregateByKey。
例如计算每个班级的平均分:
scores = sc.parallelize([
("class-a", 80),
("class-a", 90),
("class-b", 70),
])
sum_count = scores.aggregateByKey(
(0, 0),
lambda acc, score: (acc[0] + score, acc[1] + 1),
lambda acc1, acc2: (acc1[0] + acc2[0], acc1[1] + acc2[1])
)
avg = sum_count.mapValues(lambda x: x[0] / x[1])
print(avg.collect())
aggregateByKey 的三个核心部分:
它比 reduceByKey 灵活,但也更容易写错。
combineByKey 可以更细粒度地控制按 key 聚合过程。
还是计算平均分:
scores = sc.parallelize([
("Tom", 80),
("Tom", 90),
("Jerry", 70),
])
combined = scores.combineByKey(
lambda score: (score, 1),
lambda acc, score: (acc[0] + score, acc[1] + 1),
lambda acc1, acc2: (acc1[0] + acc2[0], acc1[1] + acc2[1])
)
avg = combined.mapValues(lambda x: x[0] / x[1])
print(avg.collect())
三个函数分别表示:
如果 reduceByKey 或 aggregateByKey 已经能表达清楚,就不用强行上 combineByKey。
map 是每条数据调用一次函数。
mapPartitions 是每个分区调用一次函数。
nums = sc.parallelize([1, 2, 3, 4], 2)
result = nums.mapPartitions(lambda iterator: (x * 2 for x in iterator))
print(result.collect())
它适合按分区初始化资源,比如每个分区创建一次外部连接:
def process_partition(iterator):
conn = create_connection()
try:
for record in iterator:
yield query_by_connection(conn, record)
finally:
conn.close()
result = rdd.mapPartitions(process_partition)
注意不要随便把整个 iterator 转成 list。分区数据过大时,内存会很难受。
写数据库、搜索引擎、消息系统时,foreachPartition 往往比 foreach 更合适。
不推荐每条数据都创建一次连接:
def write_one(record):
conn = create_connection()
try:
write_record(conn, record)
finally:
conn.close()
rdd.foreach(write_one)
更推荐每个分区创建一次连接:
def write_partition(iterator):
conn = create_connection()
try:
for record in iterator:
write_record(conn, record)
finally:
conn.close()
rdd.foreachPartition(write_partition)
这样可以明显减少连接创建开销。
查看分区数:
print(rdd.getNumPartitions())
减少分区:
small = rdd.coalesce(2)
重新分区:
more = rdd.repartition(10)
repartition 会触发 Shuffle。coalesce 默认不触发 Shuffle,更适合在过滤后数据量变少时减少分区。
例如减少输出小文件:
filtered = big_rdd.filter(lambda x: x is not None)
filtered.coalesce(4).saveAsTextFile("hdfs:///output/result")
RDD 默认不会自动保存中间结果。如果一个 RDD 被多个行动算子重复使用,最好缓存。
parsed = lines.map(parse_line).filter(lambda x: x is not None)
parsed.cache()
print(parsed.count())
print(parsed.take(10))
cache 等价于使用默认内存级别。
也可以使用 persist 指定存储级别:
from pyspark import StorageLevel
parsed.persist(StorageLevel.MEMORY_AND_DISK)
不用了可以释放:
parsed.unpersist()
缓存适合复用数据,不适合见到 RDD 就缓存。缓存太多会占用 Executor 内存。
如果一个 RDD 的转换链很长,可以使用 checkpoint 截断血缘关系。
先设置 checkpoint 目录:
sc.setCheckpointDir("hdfs:///spark/checkpoint")
使用:
result = rdd.map(func1).filter(func2).reduceByKey(lambda a, b: a + b)
result.cache()
result.checkpoint()
result.count()
checkpoint 需要行动算子触发。和 cache 一起使用,通常可以减少重复计算。
如果每个任务都要使用一份小数据,可以使用广播变量。
例如一个小字典:
country_dict = {
"cn": "China",
"us": "United States",
}
bc_country = sc.broadcast(country_dict)
codes = sc.parallelize(["cn", "us", "unknown"])
result = codes.map(lambda code: bc_country.value.get(code, "Unknown"))
print(result.collect())
广播变量适合:
不要广播太大的对象,否则 Driver 和 Executor 都会有压力。
累加器适合做辅助计数。
invalid_count = sc.accumulator(0)
def parse_line(line):
parts = line.split(" ")
if len(parts) < 3:
invalid_count.add(1)
return None
return parts
parsed = lines.map(parse_line).filter(lambda x: x is not None)
parsed.count()
print(invalid_count.value)
累加器适合做监控和统计,不适合参与关键业务逻辑。
原因是任务失败重试时,累加器可能出现重复累加,需要谨慎理解它的值。
假设日志格式是:
ip url status
例如:
10.0.0.1 /home 200
10.0.0.2 /detail 200
10.0.0.1 /home 500
统计每个 URL 的访问次数:
lines = sc.textFile("hdfs:///data/access.log")
url_count = (
lines
.map(lambda line: line.split(" "))
.filter(lambda arr: len(arr) >= 3)
.map(lambda arr: (arr[1], 1))
.reduceByKey(lambda a, b: a + b)
)
url_count.saveAsTextFile("hdfs:///output/url-count")
统计状态码数量:
status_count = (
lines
.map(lambda line: line.split(" "))
.filter(lambda arr: len(arr) >= 3)
.map(lambda arr: (arr[2], 1))
.reduceByKey(lambda a, b: a + b)
)
print(status_count.collect())
取访问量最高的几个 URL:
top_urls = (
url_count
.map(lambda x: (x[1], x[0]))
.sortByKey(ascending=False)
.take(10)
)
print(top_urls)
如果只是在 Driver 端取少量结果,take 是可以接受的。不要用 collect 拉全量结果。
假设数据格式是:
user_id tag1,tag2,tag3
示例数据:
data = sc.parallelize([
"u1 redis,spark,hadoop",
"u2 spark,kafka",
"u3 redis",
])
展开成 (user_id, tag):
user_tags = (
data
.map(lambda line: line.split(" "))
.filter(lambda arr: len(arr) == 2)
.map(lambda arr: (arr[0], arr[1]))
.flatMapValues(lambda tags: tags.split(","))
)
print(user_tags.collect())
统计每个标签出现次数:
tag_count = (
user_tags
.map(lambda x: (x[1], 1))
.reduceByKey(lambda a, b: a + b)
)
print(tag_count.collect())
RDD 更灵活,适合处理结构不固定的数据,也适合写一些底层控制更多的逻辑。
DataFrame 更适合结构化数据处理。它有优化器,很多查询和聚合可以自动优化,代码也更接近 SQL。
如果是常规 ETL、报表、聚合分析,DataFrame 往往更省心。
如果是学习 Spark 原理、处理非结构化数据、或者需要非常自由的函数式处理方式,RDD 仍然很有价值。
真正写好 RDD 任务,关键不只是会调用 API,还要知道哪些操作会触发 Shuffle,哪些操作会把数据拉回 Driver,哪些场景会造成分区不均或内存压力。
理解这些之后,RDD 就不只是一个 API 集合,而是一套清晰的分布式计算思路。