Scala Closures

Closures are an important concept in functional programming. In Scala, a closure is a function that can access variables in the scope where it was defined, even when called outside that scope.

What is a Closure

A closure is a function value that references one or more variables outside its function body. These referenced variables are called "free variables".

Basic Closure Example

scala
object BasicClosure {
  /** Shows that a closure reads the current value of a captured `var` and can also update it. */
  def main(args: Array[String]): Unit = {
    var factor = 3
    // `factor` is a free variable of this lambda: it is declared outside the lambda body
    val multiplier: Int => Int = n => n * factor

    println(multiplier(10))  // 30

    factor = 5               // reassigning the captured var ...
    println(multiplier(10))  // ... is seen by the closure on its next call: 50

    // The closure captures the variable itself, not a snapshot of its value,
    // so updates made inside the closure are visible outside it too.
    var counter = 0
    val increment: () => Int = () => { counter += 1; counter }

    (1 to 3).foreach(_ => println(increment()))  // 1, 2, 3
    println(s"Counter value: $counter")          // 3
  }
}

Closure vs Regular Function

scala
object ClosureVsFunction {
  /** Contrasts a plain method, a closed function literal, and closures over `val`s and `var`s. */
  def main(args: Array[String]): Unit = {
    // A local method: its result depends only on its own parameters
    def add(a: Int, b: Int): Int = a + b

    // A closure: the body refers to `base`, which is declared outside the lambda
    val base = 10
    val addToBase: Int => Int = base + _

    println(add(5, 3))      // 8
    println(addToBase(5))   // 15

    // A function literal with no free variables (a "closed term")
    val square: Int => Int = n => n * n
    println(square(4))      // 16

    // A closure over a var reads whatever the var holds at call time
    var multiplier = 2
    val multiplyBy: Int => Int = _ * multiplier

    println(multiplyBy(5))  // 10
    multiplier = 3
    println(multiplyBy(5))  // 15
  }
}

Creating and Using Closures

Functions that Return Closures

scala
object ClosureFactory {
  /** Returns a function that adds the captured `increment` to its argument. */
  def makeAdder(increment: Int): Int => Int = _ + increment

  /** Returns a function that multiplies its argument by the captured `factor`. */
  def makeMultiplier(factor: Int): Int => Int = n => factor * n

  /**
   * Returns a counter closure. Each call increments and returns a private count
   * that starts at `start`; every counter built here owns its own count variable.
   */
  def makeCounter(start: Int = 0): () => Int = {
    var current = start
    () => { current += 1; current }
  }

  def main(args: Array[String]): Unit = {
    // Adders that differ only in the captured increment
    val add5 = makeAdder(5)
    val add10 = makeAdder(10)
    List(add5(3), add10(3)).foreach(println)     // 8, 13

    // Multipliers that differ only in the captured factor
    val double = makeMultiplier(2)
    val triple = makeMultiplier(3)
    List(double(4), triple(4)).foreach(println)  // 8, 12

    // Two counters from the same factory do not share state
    val counter1 = makeCounter()
    val counter2 = makeCounter(100)
    List(counter1(), counter1(), counter2(), counter2()).foreach(println)  // 1, 2, 101, 102
  }
}

Closures as Parameters

scala
object ClosureAsParameter {
  /** Applies `operation` to every element of `numbers`, preserving order. */
  def applyOperation(numbers: List[Int], operation: Int => Int): List[Int] =
    for (n <- numbers) yield operation(n)

  /** Keeps, in order, the elements of `numbers` for which `predicate` holds. */
  def filterNumbers(numbers: List[Int], predicate: Int => Boolean): List[Int] =
    for (n <- numbers if predicate(n)) yield n

  /**
   * Keeps the elements accepted by `filter`, then transforms the survivors with `processor`.
   * All filtering happens before any transformation.
   */
  def processData[T](data: List[T], processor: T => T, filter: T => Boolean): List[T] = {
    val kept = data.filter(filter)
    kept.map(processor)
  }

  def main(args: Array[String]): Unit = {
    val numbers = (1 to 10).toList

    val threshold = 5
    val multiplier = 2

    // Both lambdas close over the locals declared just above
    val isGreaterThanThreshold: Int => Boolean = _ > threshold
    val multiplyByFactor: Int => Int = _ * multiplier

    val filtered = filterNumbers(numbers, isGreaterThanThreshold)
    val transformed = applyOperation(numbers, multiplyByFactor)

    Seq(
      s"原始数据: $numbers",
      s"大于 $threshold 的数: $filtered",
      s"乘以 $multiplier: $transformed"
    ).foreach(println)

    // The same closures are reused inside a higher-order pipeline
    val processed = processData(numbers, multiplyByFactor, isGreaterThanThreshold)
    println(s"先过滤再变换: $processed")
  }
}

Advanced Closure Usage

Currying and Closures

scala
object CurryingAndClosures {
  /** Curried addition: `add(x) _` partially applies it to obtain an `Int => Int`. */
  def add(x: Int)(y: Int): Int = x + y

  /** The same idea written by hand: returns a closure that captures `x`. */
  def addClosure(x: Int): Int => Int = (y: Int) => x + y

  /**
   * Curried arithmetic on two doubles, selected by `operation`
   * ("add", "subtract", "multiply" or "divide").
   *
   * @throws IllegalArgumentException on division by zero or an unknown operation name
   */
  def calculate(operation: String)(x: Double)(y: Double): Double = {
    operation match {
      case "add" => x + y
      case "subtract" => x - y
      case "multiply" => x * y
      case "divide" => if (y != 0) x / y else throw new IllegalArgumentException("Division by zero")
      case _ => throw new IllegalArgumentException("Unknown operation")
    }
  }

  def main(args: Array[String]): Unit = {
    // Partial application of a curried method: the trailing `_` eta-expands the remaining parameter list
    val add5 = add(5) _
    println(add5(3))     // 8

    // Currying written explicitly with a closure
    val addClosure5 = addClosure(5)
    println(addClosure5(3))  // 8

    // Eta-expanding `calculate` after its first list yields Double => Double => Double
    val addOperation = calculate("add") _
    val multiplyOperation = calculate("multiply") _

    // These are already function values, so applying one argument directly returns
    // Double => Double. A trailing `_` is only valid after a method, not a function value.
    val add10 = addOperation(10)
    val multiplyBy3 = multiplyOperation(3)

    println(add10(5))        // 15.0
    println(multiplyBy3(4))  // 12.0
  }
}

Closures and Collection Operations

scala
object ClosuresWithCollections {
  /**
   * Sums `n * multiplier` over the elements `n` of `numbers` that are divisible by `divisor`.
   * The lambdas passed to `filter` and `map` are closures over `divisor` and `multiplier`.
   *
   * @return the sum, or 0 when no element is divisible (`sum` is safe on an empty list, unlike `reduce`)
   * @throws ArithmeticException if `divisor` is 0
   */
  def sumOfScaledMultiples(numbers: List[Int], divisor: Int, multiplier: Int): Int =
    numbers
      .filter(_ % divisor == 0)
      .map(_ * multiplier)
      .sum

  def main(args: Array[String]): Unit = {
    val numbers = List(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
    val words = List("scala", "java", "python", "javascript", "go")

    // Filtering: the predicate closes over minLength
    val minLength = 4
    val longWords = words.filter(_.length >= minLength)
    println(s"长度至少为 $minLength 的单词: $longWords")

    // Transformation: the mapping function closes over prefix
    val prefix = "lang_"
    val prefixedWords = words.map(prefix + _)
    println(s"添加前缀: $prefixedWords")

    // Partitioning: the predicate closes over threshold
    val threshold = 5
    val (small, large) = numbers.partition(_ < threshold)
    println(s"小于 $threshold: $small")
    println(s"大于等于 $threshold: $large")

    // Chained operations whose lambdas close over divisor and multiplier
    // (named `divisor` rather than `filter` so it is not confused with the `filter` method)
    val multiplier = 2
    val divisor = 3
    val result = sumOfScaledMultiples(numbers, divisor, multiplier)

    println(s"能被 $divisor 整除的数乘以 $multiplier 后的和: $result")
  }
}

Closures and State Management

scala
object ClosureStateManagement {
  /**
   * Returns a closure acting as a bank account. Its `balance` is private state captured
   * by the closure and can only be read or changed through the returned function.
   *
   * Operations (the amount argument is ignored for "balance"):
   *  - "deposit":  adds `amount` and returns the new balance
   *  - "withdraw": subtracts `amount` if the balance covers it and returns the new balance
   *  - "balance":  returns the current balance
   *
   * NOTE: Double is used for simplicity; real monetary code would normally use BigDecimal.
   *
   * @throws IllegalArgumentException for a negative (or NaN) amount on deposit/withdraw,
   *                                  insufficient funds, or an unknown operation
   */
  def createBankAccount(initialBalance: Double): (String, Double) => Double = {
    var balance = initialBalance

    (operation: String, amount: Double) => {
      operation match {
        case "deposit" =>
          // Without this check a negative deposit would silently act as a withdrawal
          require(amount >= 0, s"Amount must be non-negative: $amount")
          balance += amount
          balance
        case "withdraw" =>
          // Without this check a negative withdrawal would increase the balance
          require(amount >= 0, s"Amount must be non-negative: $amount")
          if (amount <= balance) {
            balance -= amount
            balance
          } else {
            throw new IllegalArgumentException("Insufficient funds")
          }
        case "balance" =>
          balance
        case _ =>
          throw new IllegalArgumentException("Unknown operation")
      }
    }
  }

  /**
   * Returns a counter closure over a private `count` starting at 0.
   * Operations: "increment", "decrement", "reset" (back to 0), "get"; each returns the resulting count.
   *
   * @throws IllegalArgumentException for an unknown operation
   */
  def createCounter(): (String) => Int = {
    var count = 0

    (operation: String) => {
      operation match {
        case "increment" =>
          count += 1
          count
        case "decrement" =>
          count -= 1
          count
        case "reset" =>
          count = 0
          count
        case "get" =>
          count
        case _ =>
          throw new IllegalArgumentException("Unknown operation")
      }
    }
  }

  def main(args: Array[String]): Unit = {
    // Bank account example
    val account = createBankAccount(1000.0)

    println(s"初始余额: ${account("balance", 0)}")
    println(s"存款 500: ${account("deposit", 500)}")
    println(s"取款 200: ${account("withdraw", 200)}")
    println(s"当前余额: ${account("balance", 0)}")

    // Counter example
    val counter = createCounter()

    println(s"初始计数: ${counter("get")}")
    println(s"递增: ${counter("increment")}")
    println(s"递增: ${counter("increment")}")
    println(s"递减: ${counter("decrement")}")
    println(s"当前计数: ${counter("get")}")
    println(s"重置: ${counter("reset")}")
  }
}

Practical Applications of Closures

Event Handling

scala
object EventHandling {
  /** An event with a name and an untyped payload whose values are printed via string interpolation. */
  case class Event(name: String, data: Map[String, Any])

  /**
   * Returns a handler that prints each event it receives, followed by its payload entries.
   * The handler keeps a private `eventCount` captured by the closure, so each handler
   * created by this factory numbers its events independently.
   */
  def createEventHandler(handlerName: String): Event => Unit = {
    var eventCount = 0

    (event: Event) => {
      eventCount += 1
      println(s"[$handlerName] 处理第 $eventCount 个事件: ${event.name}")
      event.data.foreach { case (key, value) =>
        println(s"  $key: $value")
      }
    }
  }

  /** Returns a handler that runs `action` only for events accepted by `condition`. */
  def createConditionalHandler(condition: Event => Boolean, action: Event => Unit): Event => Unit = {
    (event: Event) => {
      if (condition(event)) {
        action(event)
      }
    }
  }

  def main(args: Array[String]): Unit = {
    val handler1 = createEventHandler("Handler1")
    val handler2 = createEventHandler("Handler2")

    val events = List(
      Event("UserLogin", Map("userId" -> 123, "timestamp" -> System.currentTimeMillis())),
      Event("UserLogout", Map("userId" -> 123, "duration" -> 3600)),
      Event("DataUpdate", Map("table" -> "users", "records" -> 5))
    )

    // handler1 sees every event, handler2 only the first. Each closure has its own
    // captured eventCount, so handler2 starts counting at 1 again.
    events.foreach(handler1)
    events.take(1).foreach(handler2)

    println("\n--- 条件处理器 ---")

    // Only process user-related events
    val userEventHandler = createConditionalHandler(
      _.name.startsWith("User"),
      event => println(s"用户事件: ${event.name}")
    )

    events.foreach(userEventHandler)
  }
}

Configuration and Factory Pattern

scala
object ConfigurationFactory {
  case class DatabaseConfig(host: String, port: Int, database: String)
  case class Connection(config: DatabaseConfig) {
    def query(sql: String): String = s"Executing '$sql' on ${config.host}:${config.port}/${config.database}"
  }

  /** Returns a factory closure that builds a new [[Connection]] for the captured `config` on every call. */
  def createConnectionFactory(config: DatabaseConfig): () => Connection = {
    () => Connection(config)
  }

  /**
   * Returns a factory closure backed by a private pool of at most `maxConnections` connections.
   * The first `maxConnections` calls create new connections. Later calls hand out the
   * already-created ones in round-robin order. If `maxConnections <= 0` nothing can be
   * pooled, so each call returns a fresh, unpooled connection.
   * The captured state is not synchronized.
   */
  def createPooledConnectionFactory(config: DatabaseConfig, maxConnections: Int): () => Connection = {
    // Both vars are captured by the closure and are invisible to callers
    var pool = Vector.empty[Connection]
    var nextToReuse = 0

    () => {
      if (pool.size < maxConnections) {
        val conn = Connection(config)
        pool = pool :+ conn
        println(s"创建新连接 (${pool.size}/$maxConnections)")
        conn
      } else if (pool.isEmpty) {
        Connection(config)
      } else {
        println("连接池已满,重用现有连接")
        val conn = pool(nextToReuse)
        // Wrap with the (now fixed) pool size so the index never overflows
        nextToReuse = (nextToReuse + 1) % pool.size
        conn
      }
    }
  }

  /**
   * Wraps `operation` in a closure. On each call, the closure tries the operation up to
   * `maxRetries` times in total (always at least once).
   *
   * @return `Some(result)` from the first successful attempt, or `None` if every attempt
   *         failed with a non-fatal exception (fatal errors propagate, per `scala.util.Try`)
   */
  def createRetryableOperation[T](maxRetries: Int)(operation: () => T): () => Option[T] = {
    val totalAttempts = math.max(1, maxRetries)

    () => {
      @scala.annotation.tailrec
      def attempt(n: Int): Option[T] =
        scala.util.Try(operation()) match {
          case scala.util.Success(value) => Some(value)
          case scala.util.Failure(_) if n < totalAttempts =>
            println(s"操作失败,重试 $n/$maxRetries")
            attempt(n + 1)
          case scala.util.Failure(e) =>
            println(s"操作最终失败: ${e.getMessage}")
            None
        }

      // Each invocation starts with a fresh attempt budget
      attempt(1)
    }
  }

  def main(args: Array[String]): Unit = {
    val config = DatabaseConfig("localhost", 5432, "myapp")

    // Simple factory
    val connectionFactory = createConnectionFactory(config)
    val conn1 = connectionFactory()
    println(conn1.query("SELECT * FROM users"))

    // Pool of two: the third request reuses the first pooled connection
    val pooledFactory = createPooledConnectionFactory(config, 2)
    val conn2 = pooledFactory()
    val conn3 = pooledFactory()
    val conn4 = pooledFactory()
    println(s"conn4 与 conn2 是同一个连接: ${conn4 eq conn2}")

    // Retry operation: fails randomly, retried up to 3 attempts per call
    val riskyOperation = () => {
      if (scala.util.Random.nextBoolean()) {
        "操作成功"
      } else {
        throw new RuntimeException("随机失败")
      }
    }

    val retryableOp = createRetryableOperation(3)(riskyOperation)
    val result = retryableOp()
    println(s"操作结果: $result")
  }
}

Closure Considerations

Memory Leak Risks

scala
object MemoryLeakExample {
  /**
   * Anti-pattern: builds ten closures that each read a single element of a large list.
   * Every closure refers to `bigList`, so the whole list stays reachable for as long
   * as any returned closure is reachable.
   */
  def createLeakyClosures(): List[() => Int] = {
    val bigList = (1 to 1000000).toList  // large amount of data
    List.range(1, 11).map(idx => () => bigList(idx))
  }

  /**
   * Preferred pattern: read each needed element eagerly, then capture only that Int.
   * The returned closures do not refer to the large list.
   */
  def createEfficientClosures(): List[() => Int] = {
    val bigList = (1 to 1000000).toList
    val picked = List.range(1, 11).map(idx => bigList(idx))
    picked.map(v => () => v)
  }

  def main(args: Array[String]): Unit = {
    println("创建高效的闭包...")
    val closures = createEfficientClosures()
    val first = closures.head
    println(s"第一个闭包的结果: ${first()}")
  }
}

Variable Capture Timing

scala
object VariableCaptureTime {
  /**
   * Pitfall: every closure captures the SAME `var`, which is declared outside the loop.
   * Closures read the variable when they are called, so after the loop they all return
   * its final value `n`.
   *
   * Declaring the `var` inside the loop body would NOT show this pitfall in Scala. The
   * body of `for (i <- ...)` is a lambda run once per element, so each iteration would
   * get a fresh variable.
   *
   * @return `n` closures, in creation order
   */
  def sharedVarClosures(n: Int): List[() => Int] = {
    var functions = List[() => Int]()
    var shared = 0
    for (i <- 1 to n) {
      shared = i
      functions = (() => shared) :: functions
    }
    functions.reverse
  }

  /**
   * Each iteration binds a fresh `val`, so each closure captures its own value and
   * the closures return 1, 2, ..., n.
   *
   * @return `n` closures, in creation order
   */
  def perIterationClosures(n: Int): List[() => Int] = {
    var functions = List[() => Int]()
    for (i <- 1 to n) {
      val x = i
      functions = (() => x) :: functions
    }
    functions.reverse
  }

  def main(args: Array[String]): Unit = {
    // Wrong way - all closures share one outer var and print 5 5 5 5 5
    println("错误的捕获方式:")
    sharedVarClosures(5).foreach(f => print(s"${f()} "))
    println()

    // Correct way - each closure captures its own value: 1 2 3 4 5
    println("正确的捕获方式:")
    perIterationClosures(5).foreach(f => print(s"${f()} "))
    println()

    // Mapping over a range binds a new `i` per element, so there is no shared var at all
    val functionalApproach = (1 to 5).map(i => () => i).toList
    println("函数式方法:")
    functionalApproach.foreach(f => print(s"${f()} "))
    println()
  }
}

Best Practices

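  1. Minimize Captures: Only capture variables that the closure truly needs
  2. Use Immutable Variables: Prefer val over var, so a closure's result does not change when outside code reassigns a captured variable
  3. Avoid Capturing Large Objects: If you only need part of an object, extract it first
  4. Be Aware of Variable Lifecycle: A closure keeps every captured variable (and whatever it references) alive for as long as the closure itself is reachable, and it sees later reassignments of captured vars
  5. Use Functional Methods: Pass small closures to map, filter, and other higher-order functions instead of writing manual loops that mutate captured variables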

Closures are one of the core features of Scala functional programming. Correctly understanding and using closures is crucial for writing high-quality Scala code.

Content is for learning and research only.