spark rdd快速上手




2018-08-15

blog_main_img

RDD,全称 Resilient Distributed Dataset,是 Spark 中最基础的分布式数据抽象。虽然现在很多项目更常用 DataFrame 和 Dataset,但理解 RDD 仍然很有价值。

这篇文章主要用 Python 代码演示 PySpark 中常见 RDD 的用法,包括 RDD 创建、常用转换算子、行动算子、键值 RDD、缓存、分区和一些实战注意点。

RDD 可以理解成什么

可以把 RDD 理解成一个分布式集合。

普通 Python 列表只存在一台机器的内存里,而 RDD 的数据会被切成多个分区,分布在不同 Executor 上并行计算。

例如:

RDD
├── partition-0
├── partition-1
├── partition-2
└── partition-3

RDD 有几个很重要的特点:

  • 数据分布在多个分区中
  • 转换操作是惰性执行的
  • 行动操作才会真正触发计算
  • 可以缓存重复使用的数据
  • 可以通过血缘关系进行容错恢复

所谓惰性执行,就是你写了 mapfilterflatMap 这类操作,Spark 不会马上运行。只有遇到 collectcounttakesaveAsTextFile 这类行动算子时,任务才会真正提交。

准备一个 SparkContext

在 PySpark 中,一般可以通过 SparkSession 拿到 SparkContext

from pyspark.sql import SparkSession

spark = SparkSession.builder \
    .appName("rdd-demo") \
    .getOrCreate()

sc = spark.sparkContext

后面的 RDD 示例都基于这个 sc

创建 RDD

最简单的方式是从 Python 集合创建:

nums = sc.parallelize([1, 2, 3, 4, 5])

也可以指定分区数:

nums = sc.parallelize([1, 2, 3, 4, 5], 2)

print(nums.getNumPartitions())

读取文本文件:

lines = sc.textFile("hdfs:///data/access.log")

如果是本地测试,也可以读取本地文件:

lines = sc.textFile("file:///tmp/access.log")

textFile 读取后得到的是一个 RDD[str],每个元素是一行文本。

map:一条进,一条出

map 会对 RDD 中的每个元素执行一个函数,并返回新的 RDD。

nums = sc.parallelize([1, 2, 3, 4])

result = nums.map(lambda x: x * 10)

print(result.collect())

输出:

[10, 20, 30, 40]

map 很适合做字段转换、类型转换、简单计算。

比如从日志中取 IP:

lines = sc.parallelize([
    "10.0.0.1 /home 200",
    "10.0.0.2 /detail 404",
    "10.0.0.3 /pay 500",
])

ips = lines.map(lambda line: line.split(" ")[0])

print(ips.collect())

flatMap:一条进,多条出

flatMapmap 很像,但它可以把一个元素拆成多个元素。

单词拆分是最经典的例子:

lines = sc.parallelize([
    "hello spark",
    "hello rdd",
])

words = lines.flatMap(lambda line: line.split(" "))

print(words.collect())

输出:

['hello', 'spark', 'hello', 'rdd']

如果用 map

result = lines.map(lambda line: line.split(" "))

print(result.collect())

输出会是:

[['hello', 'spark'], ['hello', 'rdd']]

map 保留嵌套结构,flatMap 会把嵌套结构打平。

filter:过滤数据

filter 用来保留满足条件的数据。

nums = sc.parallelize([1, 2, 3, 4, 5, 6])

even = nums.filter(lambda x: x % 2 == 0)

print(even.collect())

输出:

[2, 4, 6]

过滤日志也很常见:

error_logs = lines.filter(lambda line: "ERROR" in line)

filter 不改变元素结构,只改变数据量。

distinct:去重

distinct 用于去重:

nums = sc.parallelize([1, 1, 2, 2, 3, 3])

unique = nums.distinct()

print(unique.collect())

输出:

[1, 2, 3]

需要注意,distinct 通常会触发 Shuffle。数据量很大时,不要随手乱用。

union、intersection、subtract

RDD 支持一些集合类操作。

rdd1 = sc.parallelize([1, 2, 3])
rdd2 = sc.parallelize([3, 4, 5])

合并:

print(rdd1.union(rdd2).collect())

输出:

[1, 2, 3, 3, 4, 5]

union 不会自动去重。

取交集:

print(rdd1.intersection(rdd2).collect())

取差集:

print(rdd1.subtract(rdd2).collect())

这些操作看起来简单,但大多涉及 Shuffle,数据量大时要关注性能。

sample:抽样

sample 可以从 RDD 中抽取一部分数据。

nums = sc.parallelize(range(100))

sampled = nums.sample(
    withReplacement=False,
    fraction=0.1
)

print(sampled.take(10))

参数说明:

  • withReplacement:是否放回抽样
  • fraction:抽样比例

抽样很适合做数据预览、调试、估算数据分布。

groupBy:按规则分组

groupBy 可以根据函数返回值分组。

nums = sc.parallelize([1, 2, 3, 4, 5, 6])

grouped = nums.groupBy(lambda x: x % 2)

result = grouped.mapValues(list).collect()

print(result)

输出类似:

[(0, [2, 4, 6]), (1, [1, 3, 5])]

groupBy 用起来方便,但它会把相同分组的数据拉到一起。如果某个分组特别大,容易造成内存压力。

如果只是做求和、计数,通常更推荐使用 reduceByKey

行动算子:真正触发计算

常见行动算子有:

rdd.collect()
rdd.count()
rdd.first()
rdd.take(10)
rdd.reduce(lambda a, b: a + b)
rdd.foreach(lambda x: print(x))
rdd.saveAsTextFile("hdfs:///output/path")

collect 会把所有数据拉回 Driver:

data = rdd.collect()

数据量小时它很好用。数据量大时,它可能把 Driver 内存打爆。

调试时更推荐:

print(rdd.take(10))

或者:

print(rdd.sample(False, 0.01).take(20))

reduce:聚合元素

reduce 可以把 RDD 中的元素合并成一个结果。

nums = sc.parallelize([1, 2, 3, 4])

total = nums.reduce(lambda a, b: a + b)

print(total)

输出:

10

reduce 的函数最好满足结合律,否则分布式执行时结果可能不稳定。

适合的场景包括:

  • 求和
  • 求最大值
  • 求最小值
  • 合并结构一致的数据

foreach:不要误会它运行在哪里

foreach 会在 Executor 端执行函数。

rdd.foreach(lambda x: print(x))

初学时很容易写出这种代码:

count = 0

def add_count(x):
    global count
    count += 1

rdd.foreach(add_count)

print(count)

这通常不会得到你想要的结果。因为 Executor 里执行的是 Driver 变量的副本,不会直接修改 Driver 端的 count

要统计数量,用:

print(rdd.count())

要做辅助统计,可以使用累加器。

Pair RDD:键值对 RDD

在 PySpark 里,键值对通常用 tuple 表示。

pairs = sc.parallelize([
    ("spark", 1),
    ("hadoop", 1),
    ("spark", 1),
])

这种结构可以使用很多按 key 处理的算子,比如 reduceByKeygroupByKeyjoinsortByKey

reduceByKey:按 key 聚合

单词统计是 reduceByKey 的经典例子:

lines = sc.parallelize([
    "hello spark",
    "hello rdd",
    "spark rdd",
])

word_count = (
    lines
    .flatMap(lambda line: line.split(" "))
    .map(lambda word: (word, 1))
    .reduceByKey(lambda a, b: a + b)
)

print(word_count.collect())

输出类似:

[('hello', 2), ('spark', 2), ('rdd', 2)]

reduceByKey 会先在分区内做本地聚合,再进行 Shuffle,通常比 groupByKey 更适合聚合统计。

groupByKey:能用,但要谨慎

groupByKey 会把相同 key 的所有 value 放到一起。

pairs = sc.parallelize([
    ("spark", 1),
    ("spark", 1),
    ("hadoop", 1),
])

grouped = pairs.groupByKey()

print(grouped.mapValues(list).collect())

输出:

[('spark', [1, 1]), ('hadoop', [1])]

如果只是为了求和,不推荐这样写:

result = pairs.groupByKey().mapValues(sum)

更推荐:

result = pairs.reduceByKey(lambda a, b: a + b)

groupByKey 会把所有 value 都拉到同一个 key 下,大 key 很容易造成内存压力。

mapValues:只处理 value

mapValues 会保留 key,只转换 value。

scores = sc.parallelize([
    ("Tom", 80),
    ("Jerry", 90),
])

result = scores.mapValues(lambda score: score + 10)

print(result.collect())

输出:

[('Tom', 90), ('Jerry', 100)]

这种写法比手动拆 tuple 更清晰。

flatMapValues:一个 value 拆成多个 value

data = sc.parallelize([
    ("user1", "redis,spark,hadoop"),
    ("user2", "spark,kafka"),
])

result = data.flatMapValues(lambda tags: tags.split(","))

print(result.collect())

输出:

[('user1', 'redis'), ('user1', 'spark'), ('user1', 'hadoop'), ('user2', 'spark'), ('user2', 'kafka')]

它适合标签展开、明细展开这类场景。

sortByKey 和 sortBy

按 key 排序:

data = sc.parallelize([
    (3, "c"),
    (1, "a"),
    (2, "b"),
])

sorted_data = data.sortByKey()

print(sorted_data.collect())

降序:

sorted_desc = data.sortByKey(ascending=False)

如果不是键值 RDD,可以用 sortBy

users = sc.parallelize([
    ("Tom", 20),
    ("Jerry", 18),
    ("Alice", 25),
])

sorted_users = users.sortBy(lambda x: x[1])

print(sorted_users.collect())

排序通常会触发 Shuffle,大数据量排序要谨慎。

join:连接两个 Pair RDD

RDD 可以按 key 做连接。

users = sc.parallelize([
    (1, "Tom"),
    (2, "Jerry"),
])

orders = sc.parallelize([
    (1, "order-a"),
    (1, "order-b"),
    (2, "order-c"),
])

joined = users.join(orders)

print(joined.collect())

输出类似:

[(1, ('Tom', 'order-a')), (1, ('Tom', 'order-b')), (2, ('Jerry', 'order-c'))]

常见连接还有:

users.leftOuterJoin(orders)
users.rightOuterJoin(orders)
users.fullOuterJoin(orders)

两个大 RDD 做 join 时,通常会产生 Shuffle。要特别注意 key 倾斜和分区数量。

如果其中一个数据集很小,可以考虑广播变量。

cogroup:按 key 聚合多个 RDD

cogroup 可以把多个 RDD 中相同 key 的 value 聚到一起。

rdd1 = sc.parallelize([
    ("a", 1),
    ("b", 2),
])

rdd2 = sc.parallelize([
    ("a", 10),
    ("a", 20),
    ("c", 30),
])

result = rdd1.cogroup(rdd2)

print(result.mapValues(lambda x: (list(x[0]), list(x[1]))).collect())

输出类似:

[('a', ([1], [10, 20])), ('b', ([2], [])), ('c', ([], [30]))]

它适合多个来源按同一个 key 汇总处理的场景。

aggregateByKey:更灵活的按 key 聚合

如果输入 value 类型和输出 value 类型不同,可以使用 aggregateByKey

例如计算每个班级的平均分:

scores = sc.parallelize([
    ("class-a", 80),
    ("class-a", 90),
    ("class-b", 70),
])

sum_count = scores.aggregateByKey(
    (0, 0),
    lambda acc, score: (acc[0] + score, acc[1] + 1),
    lambda acc1, acc2: (acc1[0] + acc2[0], acc1[1] + acc2[1])
)

avg = sum_count.mapValues(lambda x: x[0] / x[1])

print(avg.collect())

aggregateByKey 的三个核心部分:

  • 初始值
  • 分区内聚合函数
  • 分区间合并函数

它比 reduceByKey 灵活,但也更容易写错。

combineByKey:更底层的聚合

combineByKey 可以更细粒度地控制按 key 聚合过程。

还是计算平均分:

scores = sc.parallelize([
    ("Tom", 80),
    ("Tom", 90),
    ("Jerry", 70),
])

combined = scores.combineByKey(
    lambda score: (score, 1),
    lambda acc, score: (acc[0] + score, acc[1] + 1),
    lambda acc1, acc2: (acc1[0] + acc2[0], acc1[1] + acc2[1])
)

avg = combined.mapValues(lambda x: x[0] / x[1])

print(avg.collect())

三个函数分别表示:

  • 第一次遇到 key 时创建初始聚合值
  • 分区内继续合并 value
  • 分区间合并聚合结果

如果 reduceByKeyaggregateByKey 已经能表达清楚,就不用强行上 combineByKey

mapPartitions:按分区处理

map 是每条数据调用一次函数。

mapPartitions 是每个分区调用一次函数。

nums = sc.parallelize([1, 2, 3, 4], 2)

result = nums.mapPartitions(lambda iterator: (x * 2 for x in iterator))

print(result.collect())

它适合按分区初始化资源,比如每个分区创建一次外部连接:

def process_partition(iterator):
    conn = create_connection()
    try:
        for record in iterator:
            yield query_by_connection(conn, record)
    finally:
        conn.close()

result = rdd.mapPartitions(process_partition)

注意不要随便把整个 iterator 转成 list。分区数据过大时,内存会很难受。

foreachPartition:按分区写外部系统

写数据库、搜索引擎、消息系统时,foreachPartition 往往比 foreach 更合适。

不推荐每条数据都创建一次连接:

def write_one(record):
    conn = create_connection()
    try:
        write_record(conn, record)
    finally:
        conn.close()

rdd.foreach(write_one)

更推荐每个分区创建一次连接:

def write_partition(iterator):
    conn = create_connection()
    try:
        for record in iterator:
            write_record(conn, record)
    finally:
        conn.close()

rdd.foreachPartition(write_partition)

这样可以明显减少连接创建开销。

coalesce 和 repartition

查看分区数:

print(rdd.getNumPartitions())

减少分区:

small = rdd.coalesce(2)

重新分区:

more = rdd.repartition(10)

repartition 会触发 Shuffle。coalesce 默认不触发 Shuffle,更适合在过滤后数据量变少时减少分区。

例如减少输出小文件:

filtered = big_rdd.filter(lambda x: x is not None)
filtered.coalesce(4).saveAsTextFile("hdfs:///output/result")

cache 和 persist

RDD 默认不会自动保存中间结果。如果一个 RDD 被多个行动算子重复使用,最好缓存。

parsed = lines.map(parse_line).filter(lambda x: x is not None)

parsed.cache()

print(parsed.count())
print(parsed.take(10))

cache 等价于使用默认内存级别。

也可以使用 persist 指定存储级别:

from pyspark import StorageLevel

parsed.persist(StorageLevel.MEMORY_AND_DISK)

不用了可以释放:

parsed.unpersist()

缓存适合复用数据,不适合见到 RDD 就缓存。缓存太多会占用 Executor 内存。

checkpoint:截断血缘关系

如果一个 RDD 的转换链很长,可以使用 checkpoint 截断血缘关系。

先设置 checkpoint 目录:

sc.setCheckpointDir("hdfs:///spark/checkpoint")

使用:

result = rdd.map(func1).filter(func2).reduceByKey(lambda a, b: a + b)

result.cache()
result.checkpoint()
result.count()

checkpoint 需要行动算子触发。和 cache 一起使用,通常可以减少重复计算。

广播变量

如果每个任务都要使用一份小数据,可以使用广播变量。

例如一个小字典:

country_dict = {
    "cn": "China",
    "us": "United States",
}

bc_country = sc.broadcast(country_dict)

codes = sc.parallelize(["cn", "us", "unknown"])

result = codes.map(lambda code: bc_country.value.get(code, "Unknown"))

print(result.collect())

广播变量适合:

  • 小表关联
  • 规则字典
  • 配置映射
  • 黑白名单

不要广播太大的对象,否则 Driver 和 Executor 都会有压力。

累加器

累加器适合做辅助计数。

invalid_count = sc.accumulator(0)

def parse_line(line):
    parts = line.split(" ")
    if len(parts) < 3:
        invalid_count.add(1)
        return None
    return parts

parsed = lines.map(parse_line).filter(lambda x: x is not None)

parsed.count()

print(invalid_count.value)

累加器适合做监控和统计,不适合参与关键业务逻辑。

原因是任务失败重试时,累加器可能出现重复累加,需要谨慎理解它的值。

一个完整例子:访问日志统计

假设日志格式是:

ip url status

例如:

10.0.0.1 /home 200
10.0.0.2 /detail 200
10.0.0.1 /home 500

统计每个 URL 的访问次数:

lines = sc.textFile("hdfs:///data/access.log")

url_count = (
    lines
    .map(lambda line: line.split(" "))
    .filter(lambda arr: len(arr) >= 3)
    .map(lambda arr: (arr[1], 1))
    .reduceByKey(lambda a, b: a + b)
)

url_count.saveAsTextFile("hdfs:///output/url-count")

统计状态码数量:

status_count = (
    lines
    .map(lambda line: line.split(" "))
    .filter(lambda arr: len(arr) >= 3)
    .map(lambda arr: (arr[2], 1))
    .reduceByKey(lambda a, b: a + b)
)

print(status_count.collect())

取访问量最高的几个 URL:

top_urls = (
    url_count
    .map(lambda x: (x[1], x[0]))
    .sortByKey(ascending=False)
    .take(10)
)

print(top_urls)

如果只是在 Driver 端取少量结果,take 是可以接受的。不要用 collect 拉全量结果。

一个完整例子:用户标签展开

假设数据格式是:

user_id tag1,tag2,tag3

示例数据:

data = sc.parallelize([
    "u1 redis,spark,hadoop",
    "u2 spark,kafka",
    "u3 redis",
])

展开成 (user_id, tag)

user_tags = (
    data
    .map(lambda line: line.split(" "))
    .filter(lambda arr: len(arr) == 2)
    .map(lambda arr: (arr[0], arr[1]))
    .flatMapValues(lambda tags: tags.split(","))
)

print(user_tags.collect())

统计每个标签出现次数:

tag_count = (
    user_tags
    .map(lambda x: (x[1], 1))
    .reduceByKey(lambda a, b: a + b)
)

print(tag_count.collect())

RDD 和 DataFrame 怎么选

RDD 更灵活,适合处理结构不固定的数据,也适合写一些底层控制更多的逻辑。

DataFrame 更适合结构化数据处理。它有优化器,很多查询和聚合可以自动优化,代码也更接近 SQL。

如果是常规 ETL、报表、聚合分析,DataFrame 往往更省心。

如果是学习 Spark 原理、处理非结构化数据、或者需要非常自由的函数式处理方式,RDD 仍然很有价值。

真正写好 RDD 任务,关键不只是会调用 API,还要知道哪些操作会触发 Shuffle,哪些操作会把数据拉回 Driver,哪些场景会造成分区不均或内存压力。

理解这些之后,RDD 就不只是一个 API 集合,而是一套清晰的分布式计算思路。