Spark notes for developers

Master Spark with a curated set of 3 developer notes — core concepts, patterns, and interview prep. Maintained by the DevRecall team.

DataFrames & Spark SQL

Apache Spark: DataFrames & Spark SQL

Apache Spark is a unified analytics engine for large-scale data processing. DataFrames are distributed collections of data organized into named columns — like a database table or a pandas DataFrame but distributed across a cluster.

SparkSession & Loading Data

from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType

spark = (
    SparkSession.builder
    .appName("MyApp")
    .config("spark.sql.adaptive.enabled", "true")
    .getOrCreate()
)

# Read from various sources
df = spark.read.csv("s3://bucket/data/*.csv", header=True, inferSchema=True)
df = spark.read.parquet("hdfs:///data/orders/")
df = spark.read.json("s3://bucket/events/")
df = spark.read.orc("/data/sales/")

# Define schema explicitly (faster than inferSchema for large files)
schema = StructType([
    StructField("order_id", IntegerType(), nullable=False),
    StructField("user_id", StringType(), nullable=True),
    StructField("amount", DoubleType(), nullable=True),
    StructField("status", StringType(), nullable=True),
])
df = spark.read.csv("orders.csv", header=True, schema=schema)

# From JDBC
df = spark.read.jdbc(
    url="jdbc:postgresql://host:5432/mydb",
    table="orders",
    properties={"user": "user", "password": "pass", "driver": "org.postgresql.Driver"}
)

# Write
df.write.mode("overwrite").parquet("output/orders/")
df.write.mode("append").partitionBy("year", "month").parquet("output/partitioned/")
df.write.format("delta").mode("overwrite").save("output/delta/")

DataFrame Operations

from pyspark.sql import functions as F
from pyspark.sql.window import Window

# Select & rename
df.select("order_id", "amount", "status")
df.select(F.col("amount").alias("total"), F.col("status").cast("string"))

# Filter
df.filter(F.col("amount") > 100)
df.filter((F.col("status") == "completed") & (F.col("amount") > 50))
df.where("status = 'completed' AND amount > 50")  # SQL string

# Add/transform columns
df.withColumn("tax", F.col("amount") * 0.1)
df.withColumn("year", F.year(F.col("created_at")))
df.withColumn("log_amount", F.log(F.col("amount")))
df.withColumn("status_upper", F.upper(F.col("status")))
df.withColumnRenamed("user_id", "userId")

# Aggregations
df.groupBy("status").agg(
    F.count("*").alias("count"),
    F.sum("amount").alias("total"),
    F.avg("amount").alias("avg"),
    F.max("amount").alias("max_amount"),
    F.collect_list("order_id").alias("order_ids"),
    F.countDistinct("user_id").alias("unique_users"),
)

# Joins
orders.join(users, orders.user_id == users.id, how="inner")
orders.join(users, "user_id", how="left")       # same column name shorthand
orders.join(
    F.broadcast(small_table),  # broadcast hint for small table
    "key", "left"
)

# Window functions
window = Window.partitionBy("user_id").orderBy(F.desc("created_at"))
(
    df.withColumn("rank", F.rank().over(window))
    .withColumn("cumsum", F.sum("amount").over(window.rowsBetween(Window.unboundedPreceding, Window.currentRow)))
    .withColumn("prev_amount", F.lag("amount", 1).over(window))
)

# Sort, limit, distinct
df.orderBy(F.desc("amount"))
df.orderBy("status", F.asc("amount"))
df.limit(100)
df.distinct()
df.dropDuplicates(["user_id"])

# Inspect
df.show(20, truncate=False)
df.printSchema()
df.describe().show()       # basic stats
df.count()

Spark SQL

# Register as temp view and query with SQL
df.createOrReplaceTempView("orders")
users.createOrReplaceTempView("users")

result = spark.sql("""
    SELECT
        u.name,
        COUNT(o.order_id) AS order_count,
        SUM(o.amount)     AS total_spent,
        RANK() OVER (ORDER BY SUM(o.amount) DESC) AS spending_rank
    FROM orders o
    JOIN users u ON o.user_id = u.id
    WHERE o.status = 'completed'
      AND o.created_at >= DATE_SUB(CURRENT_DATE(), 30)
    GROUP BY u.id, u.name
    HAVING SUM(o.amount) > 500
    ORDER BY total_spent DESC
    LIMIT 100
""")

result.show()

# Use catalog for persistent tables (Hive metastore)
df.write.saveAsTable("mydb.orders")
spark.sql("SELECT COUNT(*) FROM mydb.orders").show()

RDDs, Transformations & Actions

Apache Spark: RDDs, Transformations & Actions

RDD Basics

RDD (Resilient Distributed Dataset) is the low-level foundation of Spark. Prefer DataFrames for most use cases — they are more optimized. Use RDDs for unstructured data or custom partitioning logic.

sc = spark.sparkContext

# Create RDDs
rdd = sc.parallelize([1, 2, 3, 4, 5], numSlices=4)
rdd = sc.textFile("hdfs:///data/logs/*.log")
rdd = sc.wholeTextFiles("hdfs:///data/files/")  # (filename, content) pairs

# Transformations (lazy — create new RDD, not computed yet)
rdd.map(lambda x: x * 2)
rdd.flatMap(lambda line: line.split(" "))    # one element → zero or more
rdd.filter(lambda x: x > 3)
rdd.distinct()
rdd.sample(withReplacement=False, fraction=0.1)
rdd.union(rdd2)
rdd.intersection(rdd2)
rdd.subtract(rdd2)
rdd.sortBy(lambda x: x, ascending=False)
rdd.repartition(8)       # shuffle and repartition
rdd.coalesce(4)          # reduce partitions without shuffle

# Key-value transformations
pairs = rdd.map(lambda x: (x % 3, x))   # (key, value) RDD
pairs.groupByKey()
pairs.reduceByKey(lambda a, b: a + b)    # more efficient than groupByKey
pairs.sortByKey()
pairs.mapValues(lambda v: v * 2)
pairs.join(other_pairs)
pairs.leftOuterJoin(other_pairs)
pairs.cogroup(other_pairs)

# Actions (trigger computation)
rdd.collect()            # bring all data to driver (careful with large datasets)
rdd.count()
rdd.first()
rdd.take(10)
rdd.top(10)
rdd.sum()
rdd.min(), rdd.max()
rdd.mean(), rdd.variance()
rdd.reduce(lambda a, b: a + b)
rdd.countByValue()
rdd.saveAsTextFile("output/")
rdd.foreach(print)
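
The comment on reduceByKey above calls it more efficient than groupByKey. The usual explanation is map-side combining: values are merged inside each partition before the shuffle, so fewer records cross the network. The plain-Python sketch below simulates that effect on hand-built partitions. It is an illustration under simplified assumptions, not Spark's implementation, and the data and function names are hypothetical.

```python
from operator import add

# Hypothetical input: three partitions of (key, value) pairs.
partitions = [
    [("a", 1), ("b", 2), ("a", 3)],
    [("a", 4), ("b", 5), ("b", 6)],
    [("a", 7), ("a", 8), ("c", 9)],
]

def shuffled_by_group_by_key(parts):
    """groupByKey-style: every record crosses the shuffle unchanged."""
    return [pair for part in parts for pair in part]

def shuffled_by_reduce_by_key(parts, func):
    """reduceByKey-style: combine values per key inside each partition first."""
    shuffled = []
    for part in parts:
        combined = {}
        for key, value in part:
            combined[key] = func(combined[key], value) if key in combined else value
        shuffled.extend(combined.items())
    return shuffled

def reduce_after_shuffle(records, func):
    """Merge the shuffled records into one final value per key."""
    result = {}
    for key, value in records:
        result[key] = func(result[key], value) if key in result else value
    return result

grouped = shuffled_by_group_by_key(partitions)
reduced = shuffled_by_reduce_by_key(partitions, add)

print(len(grouped), len(reduced))  # 9 6
print(reduce_after_shuffle(grouped, add) == reduce_after_shuffle(reduced, add))  # True
```

Both paths give the same totals, but the reduceByKey-style path sends at most one record per key per partition. The saving grows with the number of repeated keys inside a partition.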

Partitioning & Persistence

from pyspark import StorageLevel

# Check partitions
rdd.getNumPartitions()
df.rdd.getNumPartitions()

# Repartition (use when data is skewed)
df.repartition(200)                    # shuffle, creates even partitions
df.repartition(200, "user_id")         # partition by column (colocates same user_id)
df.coalesce(50)                        # reduce partitions (no shuffle)

# Persist / cache (avoid recomputing expensive transformations)
df.cache()                             # = persist(MEMORY_AND_DISK)
df.persist(StorageLevel.MEMORY_ONLY)
df.persist(StorageLevel.DISK_ONLY)
df.persist(StorageLevel.MEMORY_AND_DISK)

# Unpersist when done
df.unpersist()

# Custom partitioner (RDD only)
pairs.partitionBy(8, lambda key: hash(key) % 8)

# Check data distribution
df.groupBy(F.spark_partition_id()).count().show()  # records per partition

Broadcast Variables & Accumulators

# Broadcast: send large lookup table to all workers once (not per task)
country_map = spark.sparkContext.broadcast({
    "US": "United States", "DE": "Germany", "JP": "Japan"
})

df.withColumn("country_name",
    F.udf(lambda code: country_map.value.get(code, "Unknown"))(F.col("country_code"))
)

# Better: use DataFrame join with broadcast hint instead of UDF
country_df = spark.createDataFrame([("US", "United States"), ("DE", "Germany")],
    ["code", "name"])
df.join(F.broadcast(country_df), df.country_code == country_df.code, "left")

# Accumulator: collect metrics from workers
error_count = spark.sparkContext.accumulator(0)

def process(row):
    try:
        parse(row)  # parse() stands in for your own parsing logic
    except Exception:
        error_count.add(1)  # workers can only add; the driver reads .value

rdd.foreach(process)
print(f"Errors: {error_count.value}")

Streaming, MLlib & Performance Tuning

Apache Spark: Streaming, MLlib & Performance Tuning

Structured Streaming

# Structured Streaming: same DataFrame API for real-time data

# Read from Kafka
stream_df = (
    spark.readStream
    .format("kafka")
    .option("kafka.bootstrap.servers", "broker:9092")
    .option("subscribe", "orders")
    .option("startingOffsets", "latest")
    .load()
)

# Parse JSON payload
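from pyspark.sql.types import StructType, StringType, DoubleType, TimestampType
schema = (
    StructType()
    .add("order_id", StringType())
    .add("amount", DoubleType())
    .add("event_time", TimestampType())  # required by the watermark and window below
)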

parsed = stream_df.select(
    F.from_json(F.col("value").cast("string"), schema).alias("data")
).select("data.*")

# Windowed aggregation (5-minute tumbling window)
windowed = (
    parsed
    .withWatermark("event_time", "10 minutes")
    .groupBy(F.window("event_time", "5 minutes"))
    .agg(F.sum("amount").alias("total_revenue"))
)

# Write to sink
query = (
    windowed.writeStream
    .outputMode("update")
    .format("console")
    .trigger(processingTime="30 seconds")
    .start()
)

# Write to Kafka
query = (
    parsed.select(
        F.col("order_id").cast("string").alias("key"),
        F.to_json(F.struct("*")).alias("value"),
    )
    .writeStream
    .format("kafka")
    .option("kafka.bootstrap.servers", "broker:9092")
    .option("topic", "processed-orders")
    .option("checkpointLocation", "s3://bucket/checkpoints/")
    .start()
)

query.awaitTermination()
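
The windowed aggregation above uses a 5-minute tumbling window with a 10-minute watermark. The plain-Python sketch below shows how an event timestamp could map to its window and how a watermark is derived from the latest event time seen. It assumes windows aligned to the Unix epoch (Spark's default when no start offset is given) and simplifies the watermark to "max event time minus delay". The helper names are hypothetical, not part of the Spark API.

```python
from datetime import datetime, timedelta, timezone

EPOCH = datetime(1970, 1, 1, tzinfo=timezone.utc)

def tumbling_window(event_time, width):
    """Return the [start, end) window containing event_time."""
    offset = (event_time - EPOCH) % width  # time elapsed since the window opened
    start = event_time - offset
    return start, start + width

def watermark(max_event_time_seen, delay):
    """Simplified watermark: latest event time seen minus the allowed delay."""
    return max_event_time_seen - delay

width = timedelta(minutes=5)
delay = timedelta(minutes=10)

event = datetime(2024, 1, 1, 12, 7, 30, tzinfo=timezone.utc)
start, end = tumbling_window(event, width)
print(start.time(), end.time())  # 12:05:00 12:10:00

latest_seen = datetime(2024, 1, 1, 12, 21, 0, tzinfo=timezone.utc)
wm = watermark(latest_seen, delay)
print(wm.time())  # 12:11:00
print(end <= wm)  # True: the window has fallen behind the watermark
```

Once the watermark moves past a window's end, the engine can finalize that window and drop its state, and events arriving later for it may be ignored.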

MLlib

from pyspark.ml.feature import VectorAssembler, StandardScaler, StringIndexer
from pyspark.ml.classification import LogisticRegression, RandomForestClassifier
from pyspark.ml.regression import LinearRegression, GBTRegressor
from pyspark.ml import Pipeline
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

# Feature engineering
indexer = StringIndexer(inputCol="category", outputCol="category_idx")
assembler = VectorAssembler(
    inputCols=["amount", "category_idx", "user_age"],
    outputCol="features"
)
scaler = StandardScaler(inputCol="features", outputCol="scaled_features")

# Model
lr = LogisticRegression(featuresCol="scaled_features", labelCol="label",
    maxIter=100, regParam=0.01)

# Pipeline
pipeline = Pipeline(stages=[indexer, assembler, scaler, lr])

# Train/test split
train, test = df.randomSplit([0.8, 0.2], seed=42)
model = pipeline.fit(train)
predictions = model.transform(test)

# Evaluate
evaluator = BinaryClassificationEvaluator(labelCol="label")
auc = evaluator.evaluate(predictions)
print(f"AUC: {auc:.4f}")

# Cross-validation + hyperparameter tuning
paramGrid = (
    ParamGridBuilder()
    .addGrid(lr.regParam, [0.01, 0.1, 1.0])
    .addGrid(lr.maxIter, [50, 100])
    .build()
)

cv = CrossValidator(estimator=pipeline, estimatorParamMaps=paramGrid,
    evaluator=evaluator, numFolds=5)
cv_model = cv.fit(train)

# Save/load
model.save("s3://bucket/models/churn-v1")
from pyspark.ml import PipelineModel
loaded = PipelineModel.load("s3://bucket/models/churn-v1")

Performance Tuning

  • Use Parquet or ORC — columnar formats enable predicate pushdown and column pruning.

  • Adaptive Query Execution (AQE): on by default since Spark 3.2; on older versions enable it with spark.sql.adaptive.enabled=true. It re-optimizes join strategies and shuffle partition counts at runtime.

  • Broadcast joins: use F.broadcast(small_df) for small tables to avoid shuffling the large side. Spark also broadcasts automatically for tables below spark.sql.autoBroadcastJoinThreshold (default 10MB).

  • Avoid UDFs: Python UDFs serialize/deserialize every row. Use built-in functions (F.*) or Pandas UDFs for vectorized execution.

  • Data skew: add random salt to skewed keys for aggregations; use salted join technique.

  • Partition pruning: always filter on partition columns (year, month) early; this avoids reading the entire dataset.

  • spark.sql.shuffle.partitions: the default of 200 is often too high for small jobs and too low for large ones. Rule of thumb: 2-3x the total number of executor cores.

  • Cache strategically: only cache DataFrames reused multiple times. Cache after filtering/aggregating, not before.
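
The data-skew tip above mentions adding a random salt to skewed keys. The plain-Python sketch below simulates the two-stage aggregation this implies: first aggregate on (key, salt) so the hot key is split into several smaller groups, then drop the salt and combine the partial results. The data, the salt range, and the variable names are hypothetical, and this is a simulation rather than Spark code.

```python
import random
from collections import Counter

NUM_SALTS = 8  # hypothetical salt range; choose it based on how skewed the key is
rng = random.Random(42)

# Hypothetical skewed input: one hot key plus many small keys.
records = [("hot_user", 1)] * 10_000 + [(f"user_{i}", 1) for i in range(100)]

# Stage 1: attach a random salt and aggregate on (key, salt).
partial = Counter()
for key, amount in records:
    salt = rng.randrange(NUM_SALTS)
    partial[(key, salt)] += amount

# Stage 2: drop the salt and aggregate the partial results per key.
totals = Counter()
for (key, _salt), subtotal in partial.items():
    totals[key] += subtotal

# The hot key is now spread over several smaller groups.
hot_parts = [v for (k, _s), v in partial.items() if k == "hot_user"]
print(len(hot_parts) <= NUM_SALTS, max(hot_parts) < 10_000)
print(totals["hot_user"])  # 10000
```

In PySpark the same idea is typically written as a groupBy on the key plus a salt column, followed by a second groupBy on the key alone. The final totals are unchanged, but no single task has to process every row of the hot key.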
