Go Type Assertion

Type assertion is a mechanism in Go for checking the actual stored type of an interface variable. It allows extracting concrete type values from interface types and is an important tool for working with interfaces.

📋 Type Assertion Basics

Basic Syntax

package main

import "fmt"

func basicTypeAssertion() {
    fmt.Println("=== 基本类型断言 ===")
    
    // 空接口可以存储任何类型
    var data interface{} = "Hello, Go!"
    
    // 基本类型断言语法:value, ok := variable.(Type)
    if str, ok := data.(string); ok {
        fmt.Printf("断言成功:%s (长度: %d)\n", str, len(str))
    } else {
        fmt.Println("断言失败:不是字符串类型")
    }
    
    // 不安全的类型断言(可能panic)
    defer func() {
        if r := recover(); r != nil {
            fmt.Printf("捕获panic: %v\n", r)
        }
    }()
    
    // 这会触发panic,因为data不是int类型
    // num := data.(int) // 取消注释会panic
    
    // 安全的类型断言
    if num, ok := data.(int); ok {
        fmt.Printf("是整数: %d\n", num)
    } else {
        fmt.Printf("不是整数类型,实际类型: %T\n", data)
    }
}

func multipleTypeAssertion() {
    fmt.Println("\n=== 多种类型处理 ===")
    
    values := []interface{}{
        42,
        "Hello",
        3.14159,
        true,
        []int{1, 2, 3},
        map[string]int{"count": 5},
    }
    
    for i, value := range values {
        fmt.Printf("值 %d (%T): ", i+1, value)
        
        // 尝试不同类型的断言
        if str, ok := value.(string); ok {
            fmt.Printf("字符串: \"%s\"\n", str)
        } else if num, ok := value.(int); ok {
            fmt.Printf("整数: %d\n", num)
        } else if f, ok := value.(float64); ok {
            fmt.Printf("浮点数: %.3f\n", f)
        } else if b, ok := value.(bool); ok {
            fmt.Printf("布尔值: %v\n", b)
        } else if slice, ok := value.([]int); ok {
            fmt.Printf("整数切片: %v\n", slice)
        } else if m, ok := value.(map[string]int); ok {
            fmt.Printf("映射: %v\n", m)
        } else {
            fmt.Printf("未知类型: %T\n", value)
        }
    }
}

func main() {
    basicTypeAssertion()
    multipleTypeAssertion()
}

Type Switch 语句

package main

import "fmt"

func typeSwitch() {
    fmt.Println("=== Type Switch 演示 ===")
    
    values := []interface{}{
        42,
        "Go Programming",
        3.14159,
        true,
        []string{"a", "b", "c"},
        map[string]interface{}{"name": "John", "age": 30},
        nil,
    }
    
    for i, value := range values {
        fmt.Printf("值 %d: ", i+1)
        
        switch v := value.(type) {
        case nil:
            fmt.Println("nil 值")
            
        case int:
            fmt.Printf("整数: %d", v)
            if v > 0 {
                fmt.Printf(" (正数)")
            } else if v < 0 {
                fmt.Printf(" (负数)")
            } else {
                fmt.Printf(" (零)")
            }
            fmt.Println()
            
        case string:
            fmt.Printf("字符串: \"%s\" (长度: %d)\n", v, len(v))
            
        case float64:
            fmt.Printf("浮点数: %.2f\n", v)
            
        case bool:
            if v {
                fmt.Println("布尔值: true")
            } else {
                fmt.Println("布尔值: false")
            }
            
        case []string:
            fmt.Printf("字符串切片: %v (元素数: %d)\n", v, len(v))
            
        case map[string]interface{}:
            fmt.Printf("映射: %v (键数: %d)\n", v, len(v))
            
        default:
            fmt.Printf("未处理的类型: %T, 值: %v\n", v, v)
        }
    }
}

func main() {
    typeSwitch()
}

🎯 Interface Type Assertion

Interface-to-Interface Assertion

package main

import "fmt"

// 定义多个接口
type Reader interface {
    Read() string
}

type Writer interface {
    Write(data string)
}

type ReadWriter interface {
    Reader
    Writer
}

type Closer interface {
    Close() error
}

// 实现类型
type File struct {
    name    string
    content string
    open    bool
}

func (f *File) Read() string {
    if !f.open {
        return ""
    }
    return f.content
}

func (f *File) Write(data string) {
    if f.open {
        f.content += data
    }
}

func (f *File) Close() error {
    f.open = false
    fmt.Printf("文件 %s 已关闭\n", f.name)
    return nil
}

func (f *File) String() string {
    return fmt.Sprintf("File{name: %s, open: %v}", f.name, f.open)
}

// 网络连接实现
type NetworkConnection struct {
    address string
    active  bool
}

func (nc *NetworkConnection) Read() string {
    if !nc.active {
        return ""
    }
    return "network data"
}

func (nc *NetworkConnection) Write(data string) {
    if nc.active {
        fmt.Printf("发送数据到 %s: %s\n", nc.address, data)
    }
}

func (nc *NetworkConnection) Close() error {
    nc.active = false
    fmt.Printf("网络连接 %s 已关闭\n", nc.address)
    return nil
}

func interfaceAssertion() {
    fmt.Println("=== 接口类型断言 ===")
    
    // 创建不同的实现
    file := &File{name: "test.txt", content: "初始内容", open: true}
    network := &NetworkConnection{address: "192.168.1.1", active: true}
    
    // 存储为不同的接口类型
    var readers []Reader = []Reader{file, network}
    
    for i, reader := range readers {
        fmt.Printf("\n--- Reader %d ---\n", i+1)
        
        // 读取数据
        data := reader.Read()
        fmt.Printf("读取数据: %s\n", data)
        
        // 检查是否也实现了 Writer 接口
        if writer, ok := reader.(Writer); ok {
            fmt.Println("✅ 也是 Writer,可以写入数据")
            writer.Write(" 新数据")
            
            // 再次读取以验证写入
            if newData := reader.Read(); newData != data {
                fmt.Printf("写入后数据: %s\n", newData)
            }
        } else {
            fmt.Println("❌ 不是 Writer")
        }
        
        // 检查是否实现了 Closer 接口
        if closer, ok := reader.(Closer); ok {
            fmt.Println("✅ 实现了 Closer,可以关闭")
            closer.Close()
        } else {
            fmt.Println("❌ 没有实现 Closer")
        }
        
        // 检查是否实现了 ReadWriter 接口
        if rw, ok := reader.(ReadWriter); ok {
            fmt.Println("✅ 实现了 ReadWriter")
            _ = rw // 避免未使用变量警告
        } else {
            fmt.Println("❌ 没有实现 ReadWriter")
        }
    }
}

func main() {
    interfaceAssertion()
}

🔍 Practical Applications of Type Assertion

JSON Data Processing

package main

import (
    "encoding/json"
    "fmt"
)

func jsonProcessing() {
    fmt.Println("=== JSON 数据处理 ===")
    
    // 模拟从 API 获取的 JSON 数据
    jsonData := `{
        "name": "张三",
        "age": 30,
        "active": true,
        "score": 95.5,
        "tags": ["developer", "golang"],
        "address": {
            "city": "北京",
            "zipcode": "100000"
        },
        "metadata": null
    }`
    
    // 解析为 map[string]interface{}
    var data map[string]interface{}
    err := json.Unmarshal([]byte(jsonData), &data)
    if err != nil {
        fmt.Printf("JSON 解析错误: %v\n", err)
        return
    }
    
    fmt.Println("解析 JSON 数据:")
    
    // 使用类型断言处理不同类型的值
    for key, value := range data {
        fmt.Printf("%s: ", key)
        
        switch v := value.(type) {
        case nil:
            fmt.Println("null")
            
        case string:
            fmt.Printf("字符串 \"%s\"\n", v)
            
        case float64: // JSON 数字都是 float64
            // 检查是否为整数
            if v == float64(int(v)) {
                fmt.Printf("整数 %d\n", int(v))
            } else {
                fmt.Printf("浮点数 %.1f\n", v)
            }
            
        case bool:
            fmt.Printf("布尔值 %v\n", v)
            
        case []interface{}:
            fmt.Printf("数组 [")
            for i, item := range v {
                if i > 0 {
                    fmt.Print(", ")
                }
                if str, ok := item.(string); ok {
                    fmt.Printf("\"%s\"", str)
                } else {
                    fmt.Printf("%v", item)
                }
            }
            fmt.Println("]")
            
        case map[string]interface{}:
            fmt.Printf("对象 {")
            first := true
            for k, val := range v {
                if !first {
                    fmt.Print(", ")
                }
                first = false
                fmt.Printf("%s: ", k)
                if str, ok := val.(string); ok {
                    fmt.Printf("\"%s\"", str)
                } else {
                    fmt.Printf("%v", val)
                }
            }
            fmt.Println("}")
            
        default:
            fmt.Printf("未知类型 %T: %v\n", v, v)
        }
    }
}

// 安全的类型断言辅助函数
func getString(data map[string]interface{}, key string) (string, bool) {
    if value, exists := data[key]; exists {
        if str, ok := value.(string); ok {
            return str, true
        }
    }
    return "", false
}

func getInt(data map[string]interface{}, key string) (int, bool) {
    if value, exists := data[key]; exists {
        if num, ok := value.(float64); ok {
            return int(num), true
        }
    }
    return 0, false
}

func getBool(data map[string]interface{}, key string) (bool, bool) {
    if value, exists := data[key]; exists {
        if b, ok := value.(bool); ok {
            return b, true
        }
    }
    return false, false
}

func safeJsonAccess() {
    fmt.Println("\n=== 安全的 JSON 数据访问 ===")
    
    jsonStr := `{"name": "李四", "age": 25, "active": true}`
    
    var data map[string]interface{}
    json.Unmarshal([]byte(jsonStr), &data)
    
    // 安全地获取各种类型的值
    if name, ok := getString(data, "name"); ok {
        fmt.Printf("姓名: %s\n", name)
    } else {
        fmt.Println("姓名字段不存在或不是字符串")
    }
    
    if age, ok := getInt(data, "age"); ok {
        fmt.Printf("年龄: %d\n", age)
    } else {
        fmt.Println("年龄字段不存在或不是数字")
    }
    
    if active, ok := getBool(data, "active"); ok {
        fmt.Printf("激活状态: %v\n", active)
    } else {
        fmt.Println("激活状态字段不存在或不是布尔值")
    }
    
    // 尝试获取不存在的字段
    if email, ok := getString(data, "email"); ok {
        fmt.Printf("邮箱: %s\n", email)
    } else {
        fmt.Println("邮箱字段不存在")
    }
}

func main() {
    jsonProcessing()
    safeJsonAccess()
}

Type Assertion in Error Handling

package main

import (
    "fmt"
    "net"
    "os"
    "syscall"
)

// 自定义错误类型
type ValidationError struct {
    Field   string
    Message string
}

func (ve ValidationError) Error() string {
    return fmt.Sprintf("验证错误 [%s]: %s", ve.Field, ve.Message)
}

type NetworkError struct {
    Op   string
    Host string
    Err  error
}

func (ne NetworkError) Error() string {
    return fmt.Sprintf("网络错误 [%s %s]: %v", ne.Op, ne.Host, ne.Err)
}

// 模拟可能返回不同错误类型的函数
func simulateOperations() []error {
    return []error{
        ValidationError{Field: "email", Message: "格式不正确"},
        NetworkError{Op: "dial", Host: "example.com", Err: fmt.Errorf("连接超时")},
        fmt.Errorf("通用错误"),
        &os.PathError{Op: "open", Path: "/nonexistent", Err: syscall.ENOENT},
        &net.OpError{Op: "dial", Net: "tcp", Addr: nil, Err: fmt.Errorf("网络不可达")},
    }
}

func errorTypeAssertion() {
    fmt.Println("=== 错误类型断言 ===")
    
    errors := simulateOperations()
    
    for i, err := range errors {
        fmt.Printf("\n错误 %d: %v\n", i+1, err)
        
        // 使用类型断言处理不同类型的错误
        switch e := err.(type) {
        case ValidationError:
            fmt.Printf("  类型: 验证错误\n")
            fmt.Printf("  字段: %s\n", e.Field)
            fmt.Printf("  消息: %s\n", e.Message)
            fmt.Printf("  建议: 请检查 %s 字段的格式\n", e.Field)
            
        case NetworkError:
            fmt.Printf("  类型: 网络错误\n")
            fmt.Printf("  操作: %s\n", e.Op)
            fmt.Printf("  主机: %s\n", e.Host)
            fmt.Printf("  建议: 检查网络连接和主机地址\n")
            
        case *os.PathError:
            fmt.Printf("  类型: 路径错误\n")
            fmt.Printf("  操作: %s\n", e.Op)
            fmt.Printf("  路径: %s\n", e.Path)
            fmt.Printf("  建议: 检查文件路径是否存在\n")
            
        case *net.OpError:
            fmt.Printf("  类型: 网络操作错误\n")
            fmt.Printf("  操作: %s\n", e.Op)
            fmt.Printf("  网络: %s\n", e.Net)
            fmt.Printf("  建议: 检查网络配置\n")
            
        default:
            fmt.Printf("  类型: 通用错误 (%T)\n", e)
            fmt.Printf("  建议: 查看错误消息获取更多信息\n")
        }
    }
}

// 错误包装和检查
func errorWrappingCheck() {
    fmt.Println("\n=== 错误包装检查 ===")
    
    // 创建包装错误
    baseErr := ValidationError{Field: "password", Message: "长度不够"}
    wrappedErr := fmt.Errorf("用户注册失败: %w", baseErr)
    
    fmt.Printf("包装错误: %v\n", wrappedErr)
    
    // 检查包装的错误类型
    var validationErr ValidationError
    if ok := fmt.Errorf(""); ok != nil { // 使用 errors.As 会更好
        // 手动检查
        if ve, ok := baseErr.(ValidationError); ok {
            fmt.Printf("找到验证错误: 字段=%s, 消息=%s\n", ve.Field, ve.Message)
        }
    }
    
    // 检查错误是否为特定类型
    checkErrorType := func(err error, typeName string) {
        switch err.(type) {
        case ValidationError:
            if typeName == "ValidationError" {
                fmt.Printf("✅ 错误类型匹配: %s\n", typeName)
            }
        case NetworkError:
            if typeName == "NetworkError" {
                fmt.Printf("✅ 错误类型匹配: %s\n", typeName)
            }
        default:
            fmt.Printf("❌ 错误类型不匹配: 期望 %s, 实际 %T\n", typeName, err)
        }
    }
    
    checkErrorType(baseErr, "ValidationError")
    checkErrorType(baseErr, "NetworkError")
}

func main() {
    errorTypeAssertion()
    errorWrappingCheck()
}

🎯 Best Practices

Safe Type Assertion Patterns

package main

import "fmt"

// 类型断言辅助函数
func assertString(i interface{}) (string, bool) {
    s, ok := i.(string)
    return s, ok
}

func assertInt(i interface{}) (int, bool) {
    n, ok := i.(int)
    return n, ok
}

func assertSlice(i interface{}) ([]interface{}, bool) {
    s, ok := i.([]interface{})
    return s, ok
}

// 通用类型检查函数
func checkType(value interface{}, expectedType string) bool {
    switch expectedType {
    case "string":
        _, ok := value.(string)
        return ok
    case "int":
        _, ok := value.(int)
        return ok
    case "float64":
        _, ok := value.(float64)
        return ok
    case "bool":
        _, ok := value.(bool)
        return ok
    case "slice":
        _, ok := value.([]interface{})
        return ok
    case "map":
        _, ok := value.(map[string]interface{})
        return ok
    default:
        return false
    }
}

// 类型转换器
type TypeConverter struct{}

func (tc TypeConverter) ToString(value interface{}) string {
    switch v := value.(type) {
    case string:
        return v
    case int:
        return fmt.Sprintf("%d", v)
    case float64:
        return fmt.Sprintf("%.2f", v)
    case bool:
        if v {
            return "true"
        }
        return "false"
    default:
        return fmt.Sprintf("%v", v)
    }
}

func (tc TypeConverter) ToInt(value interface{}) (int, error) {
    switch v := value.(type) {
    case int:
        return v, nil
    case float64:
        return int(v), nil
    case string:
        // 这里可以添加字符串到整数的转换逻辑
        return 0, fmt.Errorf("无法将字符串 '%s' 转换为整数", v)
    default:
        return 0, fmt.Errorf("无法将类型 %T 转换为整数", v)
    }
}

func bestPracticesDemo() {
    fmt.Println("=== 最佳实践演示 ===")
    
    data := []interface{}{
        "Hello",
        42,
        3.14,
        true,
        []interface{}{1, 2, 3},
        map[string]interface{}{"key": "value"},
    }
    
    converter := TypeConverter{}
    
    for i, value := range data {
        fmt.Printf("\n值 %d (%T): %v\n", i+1, value, value)
        
        // 安全的类型检查
        if checkType(value, "string") {
            fmt.Println("  ✅ 是字符串类型")
        }
        
        if checkType(value, "int") {
            fmt.Println("  ✅ 是整数类型")
        }
        
        // 类型转换
        strValue := converter.ToString(value)
        fmt.Printf("  转换为字符串: \"%s\"\n", strValue)
        
        if intValue, err := converter.ToInt(value); err == nil {
            fmt.Printf("  转换为整数: %d\n", intValue)
        } else {
            fmt.Printf("  整数转换失败: %v\n", err)
        }
        
        // 使用 type switch 的安全模式
        switch v := value.(type) {
        case string:
            fmt.Printf("  字符串长度: %d\n", len(v))
        case int:
            fmt.Printf("  整数平方: %d\n", v*v)
        case float64:
            fmt.Printf("  浮点数四舍五入: %.0f\n", v)
        case bool:
            fmt.Printf("  布尔值反转: %v\n", !v)
        case []interface{}:
            fmt.Printf("  切片长度: %d\n", len(v))
        case map[string]interface{}:
            fmt.Printf("  映射键数: %d\n", len(v))
        }
    }
}

func main() {
    bestPracticesDemo()
}

🎓 Summary

In this chapter, we comprehensively learned about Go type assertions:

  • Basic syntax: safe and unsafe type assertion forms
  • Type Switch: handling multiple types with switch statements
  • Interface assertion: interface-to-interface type assertions
  • Practical applications: usage in JSON processing and error handling
  • Best practices: safe type assertion patterns and helper functions

Type assertion is an important tool in Go for handling interfaces and dynamic types. Correct usage enables more flexible and safer code.


Next, we will learn about Go Inheritance to understand Go's design philosophy of composition over inheritance.

::: tip Type Assertion Tips

  • Always use the safe type assertion form (with ok return value)
  • Prefer type switch for handling multiple types
  • Create helper functions for commonly used type assertions
  • Pay special attention to type checking when processing external data :::