SAST专项之gosec项目源码学习(一)——扫描逻辑梳理

前言

gosec是一个公开源码的AST扫描器还结合了ai修复代码的前沿功能,是一个开发学习的好工具

securego/gosec:Go 安全检查员 — securego/gosec: Go security checker

image-20260208122820695

我们将通过查看源码的方式,临摹实现一个mini gosec包含从扫描节点到编写规则到最后的ai修复的核心功能,着重学习规则的编写和漏洞的识别和修复

这个项目是仅针对go语言项目的扫描用的库和我们之前用的go-tree-sitter也不一样,用的是原生的三个库

1
2
3
"go/ast"
"go/parser"
"go/token"

库的使用

go-tree-sitter一样,一个比较固定获取某个文件或者代码文本所有AST节点的固定写法如下

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
package main

import (
"go/ast"
"go/parser"
"go/token"
)

func main() {
src := `
package main
import "fmt"
func main() {
fmt.Sprintf("SELECT * FROM users")
}
`
fset := token.NewFileSet()
f, _ := parser.ParseFile(fset, "", src, 0)

// 神奇的一行:打印出 AST 的结构
ast.Print(fset, f)
}

其中

1
fset := token.NewFileSet()

fset是用来建立一个文件集对象,这里的AST每个节点为了节省内存,某个函数某个变量名不直接存储在第几行第几列,只存储一个整数,需要fset把这个整数还原回它的行列数

接着是

1
node, err := parser.ParseFile(fset, filePath, nil, parser.ParseComments)
  • 第一个参数就是我们刚刚创建的fset,他用来记录解析过程中得到的文件信息
  • 第二个也就是go的文件路径
  • 第三个参数可以选择传入nil或者字符串,如果传入nil就去参数2提供的路径中读取文件内容,如果传字符串就不会去解析参数2的内容,直接把字符串传入的代码解析
  • 第四个参数就是解析模式,0代表只解析代码逻辑丢弃所有注释,parser.ParseComments代表保留所有注释,parser.Trace代表打印解析过程

上面我们使用了ast.print()直接把所有节点打印了出来,下面我们使用ast.Inspect()遍历所有ast节点

遍历扫描节点

ast.Inspect

使用ast.Inspect遍历是一个通过回调函数的方式来遍历的

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
package main

import (
"fmt"
"go/ast"
"go/parser"
"go/token"
)

func main() {
src := `
package main
import "fmt"
func main() {
fmt.Sprintf("SELECT * FROM users")
}
`
fset := token.NewFileSet()
node, _ := parser.ParseFile(fset, "", src, 0)

ast.Inspect(node, func(n ast.Node) bool {
fmt.Println(n)
return true
})
}

但是这样子做打印出来的是一些go的原始结构体,所以还是得先通过ast.Print查看数结构和具体字段名称再开始编写遍历查找节点的逻辑

ast.Walk

在我们看的项目gosec中的主程序使用的遍历方式并不是上面提到的inpect

那么这两有什么区别呢

  • ast.Inspect : 适合快速、轻量级的脚本或简单检查。你需要在一个大的闭包函数里写所有的逻辑。
  • ast.Visitor (ast.Walk) : 适合构建复杂的、面向对象的分析器。它允许你将状态(如 Context , Issues 列表)封装在结构体中,而不是依赖闭包捕获外部变量

这个demo把Walk用法和优势体现特比较明显

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
package main

import (
"fmt"
"go/ast"
"go/parser"
"go/token"
)

// 1. 定义我们的 Visitor 结构体
// 它可以携带状态,比如当前的深度,或者我们收集到的信息
type MyVisitor struct {
depth int // 当前节点的深度,用于缩进打印
}

// 2. 实现 ast.Visitor 接口
// Visit 方法会被 ast.Walk 在每个节点上调用
func (v *MyVisitor) Visit(node ast.Node) ast.Visitor {
// node 为 nil 表示当前分支遍历结束,正在回溯
if node == nil {
return nil
}


// 打印当前节点类型
indent := strings.Repeat(" ", v.depth)
fmt.Printf("%s[%s] %+v\n", indent, reflect.TypeOf(node).Elem().Name(), node)

// 3. 返回一个新的 Visitor 用于遍历子节点
// 这里我们将深度 +1,这样子节点打印时就会缩进
// 注意:我们必须返回一个 Visitor,如果返回 nil,就不会继续遍历这个节点的子节点了
return &MyVisitor{
depth: v.depth + 1,
}
}

func main() {
// 模拟一段简单的 Go 代码
src := `
package main

func main() {
x := 42
println(x)
}
`
// 创建文件集
fset := token.NewFileSet()

// 解析代码为 AST
f, err := parser.ParseFile(fset, "", src, 0)
if err != nil {
panic(err)
}

fmt.Println("=== 开始遍历 AST ===")

// 4. 启动遍历
// ast.Walk 会从根节点 f 开始,递归调用 MyVisitor.Visit
visitor := &MyVisitor{depth: 0}
ast.Walk(visitor, f)

fmt.Println("=== 遍历结束 ===")
}

我们可以从gosec的visitor对象开始看起,看看它是怎么处理我们每一个节点的,这部分gosec写得不是很长

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
//记录信息的结构体
type astVisitor struct {
gosec *Analyzer
context *Context
issues []*issue.Issue
stats *Metrics
ignoreNosec bool
showIgnored bool
trackSuppressions bool
}

func (v *astVisitor) Visit(n ast.Node) ast.Visitor {
//专门判断接口类型的断言写法,这里只判断n的类型是不是ast.File
switch i := n.(type) {
case *ast.File:
//进入导入库检测流程
v.context.Imports.TrackFile(i)
}

//核心分发逻辑,RegisteredFor(n)根据当前节点的类型去RuleSet(规则集)里查找所有订阅了该类型节点的规则
for _, rule := range v.gosec.ruleset.RegisteredFor(n) {
//执行规则,调用具体规则的Match方法,把当前节点n和上下文v.context传给它
issue, err := rule.Match(n, v.context)
if err != nil {
//错误处理,记录日志
file, line := GetLocation(n, v.context)
file = path.Base(file)
v.gosec.logger.Printf("Rule error: %v => %s (%s:%d)\n", reflect.TypeOf(rule), err, file, line)
}
//收集结果
v.issues = v.gosec.updateIssues(issue, v.issues, v.stats, v.context.Ignores)
}
return v
}

可以看到,gosec这个项目是使用了一个订阅+匹配的逻辑去扫描漏洞,也就是先把可能出现漏洞的节点订阅之后再通过上下文匹配是否漏洞存在

节点订阅与规则载入

定位到rule.go,首先看它的结构体

1
2
3
4
5
6
7
type RuleSet struct {
//一个记录被订阅的AST节点对应一系列规则的映射关系表
Rules map[reflect.Type][]Rule

// 记录哪些规则被禁用了
RuleSuppressedMap map[string]bool
}

那么规则是怎么注册进去的呢,同样看到同个文件下的Register方法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
//属于结构体RuleSet的注册方法
func (r RuleSet) Register(rule Rule, isSuppressed bool, nodes ...ast.Node) {
for _, n := range nodes {
//获取节点类型
t := reflect.TypeOf(n)
//如果节点被订阅,也就是存在Rules的键中就append它
if rules, ok := r.Rules[t]; ok {
r.Rules[t] = append(rules, rule)
} else {
r.Rules[t] = []Rule{rule}
}
}
//处理禁止的逻辑
r.RuleSuppressedMap[rule.ID()] = isSuppressed
}

我们可以看看Register在哪里被调用了

analyze.go看到被调用

1
2
3
4
5
6
func (gosec *Analyzer) LoadRules(ruleDefinitions map[string]RuleBuilder, ruleSuppressed map[string]bool) {
for id, def := range ruleDefinitions {
r, nodes := def(id, gosec.config)
gosec.ruleset.Register(r, ruleSuppressed[id], nodes...)
}
}

我们继续往上找LoadRules在哪里被调用了

可以看到在程序入口main.go的run方法创建扫描器的时候被调用了

image-20260208165400081

我们看看传入的ruleList是怎么来的

在上面几行可以找到ruleList := loadRules(includeRules, excludeRules)

跟进loadRules方法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
func loadRules(include, exclude string) rules.RuleList {
var filters []rules.RuleFilter
if include != "" {
logger.Printf("Including rules: %s", include)
including := strings.Split(include, ",")
filters = append(filters, rules.NewRuleFilter(false, including...))
} else {
logger.Println("Including rules: default")
}

if exclude != "" {
logger.Printf("Excluding rules: %s", exclude)
excluding := strings.Split(exclude, ",")
filters = append(filters, rules.NewRuleFilter(true, excluding...))
} else {
logger.Println("Excluding rules: default")
}
return rules.Generate(*flagTrackSuppressions, filters...)
}

这里经过一些过滤之后进入到Generate方法正式创建我们的规则

在这里就可以看到我们所有以G为开头命名的规则集了

image-20260208165834950

我们看看这个Generate是怎么运作的,如果我们想加入自定义订阅的节点和规则又应该怎么加

可以看到每条规则的最后是每条规则对应的一个函数,它们统一归类在rules文件夹的对应文件中

并且这些函数都返回了它们所订阅的节点信息

image-20260208171018329

所以如果我们想自己添加一个规则和它所订阅的节点

只需先在rules文件夹创建一个go文件,这个文件需要满足条件如下

  • 定义结构体 : 包含规则所需的元数据。
  • 实现 Match 方法 : 编写核心检测逻辑。
  • 实现构造函数 : 返回规则实例和你 想订阅的节点类型
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
// rules/my_rule.go

type MyRule struct {
issue.MetaData
}

func (r *MyRule) ID() string { return r.MetaData.ID }

func (r *MyRule) Match(n ast.Node, c *gosec.Context) (*issue.Issue, error) {
// 你的检测逻辑
return nil, nil
}

func NewMyRule(id string, conf gosec.Config) (gosec.Rule, []ast.Node) {
return &MyRule{...}, []ast.Node{
(*ast.FuncDecl)(nil), // 订阅函数声明节点
(*ast.CallExpr)(nil), // 订阅函数调用节点
}
}

接着在rulelist.go中添加你的规则信息和构造函数即可

1
2
3
4
5
6
7
// rules/rulelist.go
rules := []RuleDefinition{
// ... 现有规则 ...

// 添加你的一行
{"G901", "My custom rule description", NewMyRule},
}

漏洞规则匹配

上面我们已经完成了节点订阅和规则载入的探究,也可以自定义添加规则了,现在到了AST扫描的一个最重要的点,就是如何通过漏洞代码的上下文信息编写出精确的匹配规则去扫描存在的漏洞

我们依然回到walk的地方

1
2
3
4
5
6
7
8
9
10
11
12
13
func (v *astVisitor) Visit(n ast.Node) ast.Visitor {
//专门判断接口类型的断言写法,这里只判断n的类型是不是ast.File
switch i := n.(type) {
case *ast.File:
//进入导入库检测流程
v.context.Imports.TrackFile(i)
}

//核心分发逻辑,RegisteredFor(n)根据当前节点的类型去RuleSet(规则集)里查找所有订阅了该类型节点的规则
for _, rule := range v.gosec.ruleset.RegisteredFor(n) {
//执行规则,调用具体规则的Match方法,把当前节点n和上下文v.context传给它
issue, err := rule.Match(n, v.context)
...

根据上面就很清晰了,这里我们遍历取出的rule也就是我们在rules文件夹下go文件定义的结构体,我们现在来看看gosec原生的match方法是怎么写的来学习一下

AST树定位

在这之前要学习一下怎么去找到我们想要的节点,我们以一个赋值语句AssignStmt为例

1
err := db.QueryRow("SELECT * FROM users WHERE id=" + id).Scan(&u)

首先使用断言先拿到整个赋值表达式

1
2
3
4
if stmt, ok := n.(*ast.AssignStmt); ok {
ast.Print(fset, stmt)
}
return true

接着可以看到里面各种节点的关系和名称

image-20260208215612808

接着我们来定位这个赋值表达式左边的变量的名称,左右两边的表达式都是一个切片

1
2
3
4
5
if stmt, ok := n.(*ast.AssignStmt); ok {
if lname, ok := stmt.Lhs[0].(*ast.Ident); ok {
fmt.Print(lname.Name)
}
}

下面我们去找到赋值表达式右边的db变量以及它调用的QueryRowscan方法以及它们的参数

1
2
3
4
5
if stmt, ok := n.(*ast.AssignStmt); ok {
if l, ok := stmt.Rhs[0].(*ast.CallExpr).Fun.(*ast.SelectorExpr).X.(*ast.CallExpr).Fun.(*ast.SelectorExpr).X.(*ast.Ident); ok {
fmt.Printf(l.Name)
}
}

其他的只要跟着ast.print出的东西对照着一层一层写进去就可以了

贴一个类型转换表

image-20260208203557171

上下文CTX

上面提到在进行match()的时候不仅传入了订阅的节点,还传入了一个ctx,也就是上下文,那么它有什么用呢,他又是什么,首先回到我们的Visitor结构体,发现它其实就是这种结构体的第二个参数

1
2
3
4
5
6
7
8
9
type astVisitor struct {
gosec *Analyzer
context *Context
issues []*issue.Issue
stats *Metrics
ignoreNosec bool
showIgnored bool
trackSuppressions bool
}

它提供了规则分析所需的 全局视野 。如果没有 ctx ,你只有一个孤零零的 AST 节点 db ,你只知道它叫 “db”。
有了 ctx ,你可以查阅:

  • ctx.Info (types.Info) : 这一行代码里变量的类型是什么
  • ctx.Pkg (types.Package) : 当前是在哪个包里
  • ctx.Config : 用户有没有设置什么忽略规则

匹配检测具体逻辑

这里我们以一个sql注入的检测为例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
func (s *sqlStrFormat) Match(n ast.Node, ctx *gosec.Context) (*issue.Issue, error) {
switch stmt := n.(type) {

case *ast.AssignStmt:

for _, expr := range stmt.Rhs {

if call, ok := expr.(*ast.CallExpr); ok {

if sel, ok := call.Fun.(*ast.SelectorExpr); ok {

if sqlCall, ok := sel.X.(*ast.CallExpr); ok && s.ContainsCallExpr(sqlCall, ctx) != nil {
return s.checkQuery(sqlCall, ctx)
}
}
if s.ContainsCallExpr(expr, ctx) != nil {
return s.checkQuery(call, ctx)
}
}
}
case *ast.ExprStmt:
if call, ok := stmt.X.(*ast.CallExpr); ok && s.ContainsCallExpr(call, ctx) != nil {
return s.checkQuery(call, ctx)
}
}
return nil, nil
}

这是整块的match方法,它订阅了两个节点,一个是AssignStmt一个是ExprStmt,我们先从AssignStmt的处理开始阅读

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
func (s *sqlStrFormat) Match(n ast.Node, ctx *gosec.Context) (*issue.Issue, error) {
switch stmt := n.(type) {
//赋值表达式节点被订阅
case *ast.AssignStmt:
// 遍历赋值语句的右值表达式 如果是a,b := fun1(),fun2(),右边会有两个调用表达式
for _, expr := range stmt.Rhs {
//进入函数表达式
if call, ok := expr.(*ast.CallExpr); ok {
//检查a.b().c()型的调用
if sel, ok := call.Fun.(*ast.SelectorExpr); ok {
//取出a.b()部分
//是因为考虑到有db.QueryRow(...).Scan(...)这样的调用
if sqlCall, ok := sel.X.(*ast.CallExpr); ok && s.ContainsCallExpr(sqlCall, ctx) != nil {
return s.checkQuery(sqlCall, ctx)
}
}
if s.ContainsCallExpr(expr, ctx) != nil {
return s.checkQuery(call, ctx)
}
}
}

我们可以看到它是把我们的一整个CallExpr和传入checkQueryContainsCallExpr之中,我们看看这两个函数做了什么,总的来说它们的逻辑如下

  • s.ContainsCallExpr : 查黑名单。判断 db.Query 这个函数名是否在我们的监控列表里。
  • s.checkQuery : 如果在名单里,进一步深入检查参数
    接下来讲解这两个部分

ContainsCallExpr包名函数名称检测

getCallInfo获取函数调用的具体信息

先是来到函数定义处

1
2
3
4
5
6
7
8
9
10
11
12
13
14
func (c CallList) ContainsCallExpr(n ast.Node, ctx *Context) *ast.CallExpr {
//获取调用信息,包括包名/类型名,函数名
selector, ident, err := GetCallInfo(n, ctx)
if err != nil {
return nil
}

//查询包名/类型名,函数名是否在黑名单
if !c.Contains(selector, ident) && !c.ContainsPointer(selector, ident) {
return nil
}

return n.(*ast.CallExpr)
}

下面我们跟进GetCallInfo

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
func GetCallInfo(n ast.Node, ctx *Context) (string, string, error) {
//检查缓存,先略过
if ctx.callCache != nil {
if res, ok := ctx.callCache[n]; ok {
return res.packageName, res.funcName, res.err
}
}

//真正的解析调用信息的函数getCallInfo
packageName, funcName, err := getCallInfo(n, ctx)

//写入缓存
if ctx.callCache != nil {
ctx.callCache[n] = callInfo{packageName, funcName, err}
}
return packageName, funcName, err
}

下面我们来到真正干活的函数getCallInfo,代码比较长,一点点注释阅读

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
func getCallInfo(n ast.Node, ctx *Context) (string, string, error) {
//确保传入的 AST 节点是函数调用表达式 (CallExpr)
switch node := n.(type) {
case *ast.CallExpr:

//分析调用的主体部分 (Fun)
switch fn := node.Fun.(type) {

// Case 1: 方法调用或包函数调用 (SelectorExpr)
// 形式如: X.Sel(),例如: user.GetName() 或 fmt.Println()
case *ast.SelectorExpr:
switch expr := fn.X.(type) {

// Sub-Case 1.1: 接收者是一个标识符
// 例如: user.Get() 或 fmt.Print() 中的 "user" 或 "fmt"
case *ast.Ident:
// 检查该标识符是否在当前作用域被解析为一个变量 (Var)
// expr.Obj != nil: 变量已定义
// expr.Obj.Kind == ast.Var: 确认为变量,排除了常量、类型名或函数名
if expr.Obj != nil && expr.Obj.Kind == ast.Var {
// 利用类型检查器 (go/types) 获取该变量的实际类型
// 例如: var u User; u.Get() -> 获取 u 的类型 "User"
t := ctx.Info.TypeOf(expr)
if t != nil {
// 返回格式: (类型名, 方法名, nil)
return t.String(), fn.Sel.Name, nil
}
return "undefined", fn.Sel.Name, fmt.Errorf("missing type info for variable: %s", expr.Name)
}

// 如果不是变量(expr.Obj 为 nil 或 Kind 不是 Var)
// 通常情况是:
// 1. 包名调用: fmt.Println() -> 返回 "fmt"
// 2. 类型转换后调用: MyType(v).Method() -> 返回 "MyType" (取决于 AST 结构)
// 3. 结构体/接口类型直接调用
return expr.Name, fn.Sel.Name, nil

// Sub-Case 1.2: 接收者也是一个选择器
// 例如: app.config.Start(),此处 expr 对应 "app.config"
case *ast.SelectorExpr:
// 确保中间字段存在 (例如 config)
if expr.Sel != nil {
// ctx从'config'查询,从上下文中找到app.config的类型
t := ctx.Info.TypeOf(expr.Sel)
if t != nil {
return t.String(), fn.Sel.Name, nil
}
return "undefined", fn.Sel.Name, fmt.Errorf("missing type info for selector")
}

// Sub-Case 1.3: 接收者是另一个函数调用 (CallExpr),即链式调用
// 例如: NewUser().Login(),此处 expr 对应 "NewUser()"
case *ast.CallExpr:
// 分析链式调用的源头函数
switch call := expr.Fun.(type) {
case *ast.Ident:
// 特殊处理内建函数 new(): new(User).Login()
// 取出 new 的第一个参数 (类型) 作为调用方类型
if call.Name == "new" && len(expr.Args) > 0 {
t := ctx.Info.TypeOf(expr.Args[0])
if t != nil {
return t.String(), fn.Sel.Name, nil
}
return "undefined", fn.Sel.Name, fmt.Errorf("missing type info for new()")
}

// 处理普通函数链式调用: GetUser().Login()
// 注意:此处代码尝试通过 AST 对象声明 (Obj.Decl) 手动查找返回值
if call.Obj != nil {
switch decl := call.Obj.Decl.(type) {
case *ast.FuncDecl:
// 获取函数定义的返回值列表
ret := decl.Type.Results
if ret != nil && len(ret.List) > 0 {
// 默认取第一个返回值作为链式调用的接收者
ret1 := ret.List[0]
if ret1 != nil {
t := ctx.Info.TypeOf(ret1.Type)
if t != nil {
return t.String(), fn.Sel.Name, nil
}
return "undefined", fn.Sel.Name, fmt.Errorf("missing return type info")
}
}
}
}
}
}

// Case 2: 普通函数调用 (Ident)
// 形式如: someFunc(),此时没有接收者对象
case *ast.Ident:
// 返回当前包名和函数名
return ctx.Pkg.Name(), fn.Name, nil
}
}

return "", "", fmt.Errorf("unable to determine call info")
}

Contains与ContainsPointer检测是否是漏洞规则函数

回到我们的大函数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
func (c CallList) ContainsCallExpr(n ast.Node, ctx *Context) *ast.CallExpr {
//我们已经获取了函数调用的具体信息
selector, ident, err := GetCallInfo(n, ctx)
if err != nil {
return nil
}

//进入到call list中检查是否存在
if !c.Contains(selector, ident) && !c.ContainsPointer(selector, ident) {
return nil
}

return n.(*ast.CallExpr)
}

可以知道这两个小检测函数都来自我们的CallList结构体,我们可以先看看这个结构体的结构

可以看到它是一个嵌套MAP

1
2
3
4
5
type set map[string]bool

// CallList is used to check for usage of specific packages
// and functions.
type CallList map[string]set

CallList结构体本身有一个ADD方法用来把规则塞进去,规则还是存在与rules/文件夹下各个文件夹中,在sql.go我们可以看见它的规则

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
var sqlCallIdents = map[string]map[string]int{
"*database/sql.Conn": {
"ExecContext": 1,
"QueryContext": 1,
"QueryRowContext": 1,
"PrepareContext": 1,
},
"*database/sql.DB": {
"Exec": 0,
"ExecContext": 1,
"Query": 0,
"QueryContext": 1,
"QueryRow": 0,
"QueryRowContext": 1,
"Prepare": 0,
"PrepareContext": 1,
},
"*database/sql.Tx": {
"Exec": 0,
"ExecContext": 1,
"Query": 0,
"QueryContext": 1,
"QueryRow": 0,
"QueryRowContext": 1,
"Prepare": 0,
"PrepareContext": 1,
},
}

第一个自然就是标准库名,第二个是函数名,第三个就是容易被注入的参数位置

而我们可以发现CallList第三个是bool型,因为通用层并不关心在第几个参数,只关心这个Call有没有在其中,当我们想要去获得具体的参数位置的时候再拿

1
2
3
4
5
6
7
8
9
10
11
12
13
14
func findQueryArg(call *ast.CallExpr, ctx *gosec.Context) (ast.Expr, error) {
typeName, fnName, err := gosec.GetCallInfo(call, ctx)
if err != nil {
return nil, err
}

if methods, ok := sqlCallIdents[typeName]; ok {
if i, ok := methods[fnName]; ok && i < len(call.Args) {
return call.Args[i], nil
}
}

return nil, fmt.Errorf("SQL argument index not found for %s.%s", typeName, fnName)
}
1
2
3
4
5
func (s *sqlStrConcat) checkQuery(call *ast.CallExpr, ctx *gosec.Context) (*issue.Issue, error) {
query, err := findQueryArg(call, ctx)
if err != nil {
return nil, err
}

我们回到我们要看的两个小函数Contains

首先先看c.Contanis

1
2
3
4
5
6
7
func (c CallList) Contains(selector, ident string) bool {
if idents, ok := c[selector]; ok {
_, found := idents[ident]
return found
}
return false
}

这就是个很简单的查表逻辑

再看c.ContainsPointer

1
2
3
4
5
6
7
8
9
10
func (c CallList) ContainsPointer(selector, indent string) bool {
if strings.HasPrefix(selector, "*") {
if c.Contains(selector, indent) {
return true
}
s := strings.TrimPrefix(selector, "*")
return c.Contains(s, indent)
}
return false
}

这是一个模糊匹配机制,具体处理以下情况

规则定义假设有一个规则,它注册时写的是 非指针类型 :

1
2
// 规则里写的是 "bytes.Buffer"
rule.Add("bytes.Buffer", "WriteString")

但在用户的实际代码里,大家通常是传 Buffer 的 指针 来用的

1
2
3
4
func process(buf *bytes.Buffer) { // 注意这里是 
*bytes.Buffer
buf.WriteString("hello")
}

匹配过程

当 gosec 扫描到 buf.WriteString(“hello”) 时:

  1. GetCallInfo : ctx.Info.TypeOf(buf) 会告诉 gosec,这个变量 buf 的类型是 *bytes.Buffer

  2. Contains :
    先拿 *bytes.Buffer 去 Map 里查。

    • Map 里只有 bytes.Buffer
    • 匹配失败 。
  3. ContainsPointer

    • 发现 selector 是 *bytes.Buffer ,以 * 开头。
    • 尝试去掉 * ,变成 bytes.Buffer 。
    • 拿 bytes.Buffer 去 Map 里查。
    • 匹配成功!

如果从这个包名和参数名的层面确认了,这是一个危险的地方,那么就到下一步参数的检测了

checkQuery传入参数检测

这个过程也比较复杂,有100多行的代码

首先就是上面提到的获得参数

1
2
3
4
5
6
7
8
9
10
11
func (s *sqlStrConcat) checkQuery(call *ast.CallExpr, ctx *gosec.Context) (*issue.Issue, error) {
//获得query之后再对query进行判断
query, err := findQueryArg(call, ctx)
if err != nil {
return nil, err
}

// Direct binary concatenation (e.g., "SELECT ..." + tainted)
if be, ok := query.(*ast.BinaryExpr); ok {
...
...

第一种情况,判断是不是db.query("select"+q)这样的语句,对于这种情况,gosec是这样匹配的

  • 先判断参数是不是一个字面量
  • 如果不是字面量,拆解参数
  • 判断开头字面量是否是sql模式
  • 判断后续变量是否可控,即不是常量

具体代码如下

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
// Direct binary concatenation (e.g., "SELECT ..." + tainted)
//判断db.query("select"+q)的情况
//先用ast.BinaryExpr判断是不是二元表达式
if be, ok := query.(*ast.BinaryExpr); ok {
//拆解操作数,例如把"SELECT " + "id=" + val -> ["SELECT ", "id=", val]
operands := gosec.GetBinaryExprOperands(be)
//检查操作数是否以字符串字面量开头
if start, ok := operands[0].(*ast.BasicLit); ok {
//检查字符串字面量是否匹配SQL模式,使用的是正则匹配
/*
var (
sqlRegexp = regexp.MustCompile("(?i)(SELECT|DELETE|INSERT|UPDATE|INTO|FROM|WHERE)( |\n|\r|\t)")
sqlFormatRegexp = regexp.MustCompile("%[^bdoxXfFp]")
)
*/
if str, e := gosec.GetString(start); e == nil && s.MatchPatterns(str) {
//获取操作数的剩余部分,例如"id=" + val -> ["id=", val]
for _, op := range operands[1:] {

// gosec.TryResolve 尝试解析变量的值,如果能解析出来那就认为是安全的
if gosec.TryResolve(op, ctx) {
continue
}
// 如果解析不了说明它是真正的变量,则爆出可能存在sql注入风险
return ctx.NewIssue(be, s.ID(), s.What, s.Severity, s.Confidence), nil
}
}
}
return nil, nil
}

第二种情况,为了解决db.query(q)这样的情况,它的匹配逻辑如下

  • 匹配是不是一个字面量,如果是”ss”这样写好的直接pass
  • 如果不是字面量,则不像上面的解析去看他是不是变量,而是把这个当做变量,然后去查阅他被声明的地方

这里使用FileSet去还原某个定义对象的位置信息

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
//获取q的定义对象,确保它是一个变量
v, ok := ctx.Info.ObjectOf(ident).(*types.Var)
if !ok {
return nil, nil
}

// Determine search scope (package-level or local)
//判断q是包级变量还是局部变量
//如果是局部变量,只需要在本文件中搜索定义
//如果是包级变量,需要在所有文件中搜索定义
isPkgLevel := ctx.Pkg != nil && v.Parent() == ctx.Pkg.Scope()

//根据q的定义位置,判断需要在哪些文件中搜索
var filesToSearch []*ast.File
if isPkgLevel {
filesToSearch = ctx.PkgFiles
} else {
callFile := gosec.ContainingFile(call, ctx)
if callFile == nil {
return nil, nil
}
filesToSearch = []*ast.File{callFile}
}

// Find the defining declaration and check for SQL patterns / initial risky concatenation
declRHS := []ast.Expr{}
foundDecl := false

// Determine the file containing the variable's defining position
//使用v.pos获取q定义的绝对位置后映射到文件信息
var declFile *ast.File
if ctx.FileSet != nil {
if posFile := ctx.FileSet.File(v.Pos()); posFile != nil {
targetName := posFile.Name()
for _, f := range filesToSearch {
if fileInfo := ctx.FileSet.File(f.Pos()); fileInfo != nil && fileInfo.Name() == targetName {
declFile = f
break
}
}
}
}

找到哪个声明变量的文件之后,我们就用ast.inspect去扫描这个文件,找到这个声明具体的表达式

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
if declFile != nil {
ast.Inspect(declFile, func(n ast.Node) bool {
switch d := n.(type) {
// 处理短变量声明,例如:var q = "SELECT * FROM users"
case *ast.ValueSpec:
//遍历左边的变量名,因为可能有多个变量被赋值,例如:var q, q2 = "SELECT * FROM users", "SELECT * FROM users2"
for _, name := range d.Names {
//检查变量名是否和q匹配
if name.Pos() == v.Pos() && ctx.Info.ObjectOf(name) == v {
//取出它的初始值,例如:var q = "SELECT * FROM users" -> ["SELECT * FROM users"]
declRHS = d.Values
foundDecl = true
return false // Stop inspection
}
}
// 处理赋值语句
case *ast.AssignStmt:
//限制只处理短变量声明,例如:q := "SELECT * FROM users"
if d.Tok == token.DEFINE { // Only short variable declarations define new vars
//和上面的ValueSpec一样,遍历左边的变量名,检查是否和q匹配
for _, lhs := range d.Lhs {
if id, ok := lhs.(*ast.Ident); ok && id.Pos() == v.Pos() && ctx.Info.ObjectOf(id) == v {
declRHS = d.Rhs
foundDecl = true
return false // Stop inspection
}
}
}
}
return true
})
}

找到这个表达式并且取出它右边的值之后就可以做最后的判断了

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
if foundDecl {
// Check for SQL patterns in initial values
//检查是否有sql模式
hasSQLPattern := false
for _, val := range declRHS {
if str, err := gosec.GetStringRecursive(val); err == nil && s.MatchPatterns(str) {
hasSQLPattern = true
break
}
}

// Check for risky initial concatenation
//检查初始值是否存在拼接,如果存在拼接,就认为是存在sql注入风险
if inj := s.findInjectionInBranch(ctx, declRHS); inj != nil {
return ctx.NewIssue(inj, s.ID(), s.What, s.Severity, s.Confidence), nil
}

//如果sql模式都没有,就不认为是sql注入点
if !hasSQLPattern {
return nil, nil
}
} else {
// No defining declaration found → assume not SQL-related
return nil, nil
}

// Check for risky mutations (query += tainted or query = query + tainted)
//验证变异,如果发现确实存在sql模式,但是目前表达式不存在拼接,检测之后是否存在拼接
for _, f := range filesToSearch {
var found *ast.AssignStmt
ast.Inspect(f, func(n ast.Node) bool {
//找出所有赋值语句
assign, ok := n.(*ast.AssignStmt)
if !ok || len(assign.Lhs) != 1 || len(assign.Rhs) != 1 {
return true
}
//确认左值是我们关注的那个q
lIdent, ok := assign.Lhs[0].(*ast.Ident)
if !ok || ctx.Info.ObjectOf(lIdent) != v {
return true
}

var appended ast.Expr
//检查两种追加方式
switch assign.Tok {
case token.ADD_ASSIGN:
//如果是+=操作,直接取右边的表达式
appended = assign.Rhs[0]
case token.ASSIGN:
//如果是=操作,需要检查右边是否是拼接操作
be, ok := assign.Rhs[0].(*ast.BinaryExpr)
if !ok || be.Op != token.ADD {
return true
}
left, ok := be.X.(*ast.Ident)
if !ok || ctx.Info.ObjectOf(left) != v {
return true
}
//如果确认右边是是加法操作,取右边的表达式
appended = be.Y
default:
return true
}

//最后判决,如果右边的表达式无法被解析为常量,则认为是变量拼接则直接告警
if !gosec.TryResolve(appended, ctx) {
found = assign
return false
}
return true
})
if found != nil {
return ctx.NewIssue(found, s.ID(), s.What, s.Severity, s.Confidence), nil
}
}

return nil, nil

这就是最后的检查流程了,到这算完成了一整个sql注入的扫描了