前言 gosec是一个公开源码的AST扫描器还结合了ai修复代码的前沿功能,是一个开发学习的好工具
securego/gosec:Go 安全检查员 — securego/gosec: Go security checker
我们将通过查看源码的方式,临摹实现一个mini gosec包含从扫描节点到编写规则到最后的ai修复的核心功能,着重学习规则的编写和漏洞的识别和修复
这个项目是仅针对go语言项目的扫描用的库和我们之前用的go-tree-sitter也不一样,用的是原生的三个库
1 2 3 "go/ast" "go/parser" "go/token"
库的使用 和go-tree-sitter一样,一个比较固定获取某个文件或者代码文本所有AST节点的固定写法如下
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 package mainimport ( "go/ast" "go/parser" "go/token" ) func main () { src := ` package main import "fmt" func main() { fmt.Sprintf("SELECT * FROM users") } ` fset := token.NewFileSet() f, _ := parser.ParseFile(fset, "" , src, 0 ) ast.Print(fset, f) }
其中
1 fset := token.NewFileSet()
fset是用来建立一个文件集对象,这里的AST每个节点为了节省内存,某个函数某个变量名不直接存储在第几行第几列,只存储一个整数,需要fset把这个整数还原回它的行列数
接着是
1 node, err := parser.ParseFile(fset, filePath, nil, parser.ParseComments)
第一个参数就是我们刚刚创建的fset,他用来记录解析过程中得到的文件信息 第二个也就是go的文件路径 第三个参数可以选择传入nil或者字符串,如果传入nil就去参数2提供的路径中读取文件内容,如果传字符串就不会去解析参数2的内容,直接把字符串传入的代码解析 第四个参数就是解析模式,0代表只解析代码逻辑丢弃所有注释,parser.ParseComments代表保留所有注释,parser.Trace代表打印解析过程 上面我们使用了ast.print()直接把所有节点打印了出来,下面我们使用ast.Inspect()遍历所有ast节点
遍历扫描节点 ast.Inspect 使用ast.Inspect遍历是一个通过回调函数的方式来遍历的
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 package mainimport ( "fmt" "go/ast" "go/parser" "go/token" ) func main () { src := ` package main import "fmt" func main() { fmt.Sprintf("SELECT * FROM users") } ` fset := token.NewFileSet() node, _ := parser.ParseFile(fset, "" , src, 0 ) ast.Inspect(node, func (n ast.Node) bool { fmt.Println(n) return true }) }
但是这样子做打印出来的是一些go的原始结构体,所以还是得先通过ast.Print查看数结构和具体字段名称再开始编写遍历查找节点的逻辑
ast.Walk 在我们看的项目gosec中的主程序使用的遍历方式并不是上面提到的inpect
那么这两有什么区别呢
ast.Inspect : 适合快速、轻量级的脚本或简单检查。你需要在一个大的闭包函数里写所有的逻辑。 ast.Visitor (ast.Walk) : 适合构建复杂的、面向对象的分析器。它允许你将状态(如 Context , Issues 列表)封装在结构体中,而不是依赖闭包捕获外部变量 这个demo把Walk用法和优势体现特比较明显
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 package mainimport ( "fmt" "go/ast" "go/parser" "go/token" ) type MyVisitor struct { depth int } func (v *MyVisitor) Visit(node ast.Node) ast.Visitor { if node == nil { return nil } indent := strings.Repeat(" " , v.depth) fmt.Printf("%s[%s] %+v\n" , indent, reflect.TypeOf(node).Elem().Name(), node) return &MyVisitor{ depth: v.depth + 1 , } } func main () { src := ` package main func main() { x := 42 println(x) } ` fset := token.NewFileSet() f, err := parser.ParseFile(fset, "" , src, 0 ) if err != nil { panic (err) } fmt.Println("=== 开始遍历 AST ===" ) visitor := &MyVisitor{depth: 0 } ast.Walk(visitor, f) fmt.Println("=== 遍历结束 ===" ) }
我们可以从gosec的visitor对象开始看起,看看它是怎么处理我们每一个节点的,这部分gosec写得不是很长
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 type astVisitor struct { gosec *Analyzer context *Context issues []*issue.Issue stats *Metrics ignoreNosec bool showIgnored bool trackSuppressions bool } func (v *astVisitor) Visit(n ast.Node) ast.Visitor { switch i := n.(type ) { case *ast.File: v.context.Imports.TrackFile(i) } for _, rule := range v.gosec.ruleset.RegisteredFor(n) { issue, err := rule.Match(n, v.context) if err != nil { file, line := GetLocation(n, v.context) file = path.Base(file) v.gosec.logger.Printf("Rule error: %v => %s (%s:%d)\n" , reflect.TypeOf(rule), err, file, line) } v.issues = v.gosec.updateIssues(issue, v.issues, v.stats, v.context.Ignores) } return v }
可以看到,gosec这个项目是使用了一个订阅+匹配的逻辑去扫描漏洞,也就是先把可能出现漏洞的节点订阅之后再通过上下文匹配是否漏洞存在
节点订阅与规则载入 定位到rule.go,首先看它的结构体
1 2 3 4 5 6 7 type RuleSet struct { Rules map [reflect.Type][]Rule RuleSuppressedMap map [string ]bool }
那么规则是怎么注册进去的呢,同样看到同个文件下的Register方法
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 func (r RuleSet) Register(rule Rule, isSuppressed bool , nodes ...ast.Node) { for _, n := range nodes { t := reflect.TypeOf(n) if rules, ok := r.Rules[t]; ok { r.Rules[t] = append (rules, rule) } else { r.Rules[t] = []Rule{rule} } } r.RuleSuppressedMap[rule.ID()] = isSuppressed }
我们可以看看Register在哪里被调用了
在analyze.go看到被调用
1 2 3 4 5 6 func (gosec *Analyzer) LoadRules(ruleDefinitions map [string ]RuleBuilder, ruleSuppressed map [string ]bool ) { for id, def := range ruleDefinitions { r, nodes := def(id, gosec.config) gosec.ruleset.Register(r, ruleSuppressed[id], nodes...) } }
我们继续往上找LoadRules在哪里被调用了
可以看到在程序入口main.go的run方法创建扫描器的时候被调用了
我们看看传入的ruleList是怎么来的
在上面几行可以找到ruleList := loadRules(includeRules, excludeRules)
跟进loadRules方法
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 func loadRules (include, exclude string ) rules.RuleList { var filters []rules.RuleFilter if include != "" { logger.Printf("Including rules: %s" , include) including := strings.Split(include, "," ) filters = append (filters, rules.NewRuleFilter(false , including...)) } else { logger.Println("Including rules: default" ) } if exclude != "" { logger.Printf("Excluding rules: %s" , exclude) excluding := strings.Split(exclude, "," ) filters = append (filters, rules.NewRuleFilter(true , excluding...)) } else { logger.Println("Excluding rules: default" ) } return rules.Generate(*flagTrackSuppressions, filters...) }
这里经过一些过滤之后进入到Generate方法正式创建我们的规则
在这里就可以看到我们所有以G为开头命名的规则集了
我们看看这个Generate是怎么运作的,如果我们想加入自定义订阅的节点和规则又应该怎么加
可以看到每条规则的最后是每条规则对应的一个函数,它们统一归类在rules文件夹的对应文件中
并且这些函数都返回了它们所订阅的节点信息
所以如果我们想自己添加一个规则和它所订阅的节点
只需先在rules文件夹创建一个go文件,这个文件需要满足条件如下
定义结构体 : 包含规则所需的元数据。 实现 Match 方法 : 编写核心检测逻辑。 实现构造函数 : 返回规则实例和你 想订阅的节点类型 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 type MyRule struct { issue.MetaData } func (r *MyRule) ID() string { return r.MetaData.ID }func (r *MyRule) Match(n ast.Node, c *gosec.Context) (*issue.Issue, error ) { return nil , nil } func NewMyRule (id string , conf gosec.Config) (gosec.Rule, []ast.Node) { return &MyRule{...}, []ast.Node{ (*ast.FuncDecl)(nil ), (*ast.CallExpr)(nil ), } }
接着在rulelist.go中添加你的规则信息和构造函数即可
1 2 3 4 5 6 7 rules := []RuleDefinition{ {"G901" , "My custom rule description" , NewMyRule}, }
漏洞规则匹配 上面我们已经完成了节点订阅和规则载入的探究,也可以自定义添加规则了,现在到了AST扫描的一个最重要的点,就是如何通过漏洞代码的上下文信息编写出精确的匹配规则去扫描存在的漏洞
我们依然回到walk的地方
1 2 3 4 5 6 7 8 9 10 11 12 13 func (v *astVisitor) Visit(n ast.Node) ast.Visitor { switch i := n.(type ) { case *ast.File: v.context.Imports.TrackFile(i) } for _, rule := range v.gosec.ruleset.RegisteredFor(n) { issue, err := rule.Match(n, v.context) ...
根据上面就很清晰了,这里我们遍历取出的rule也就是我们在rules文件夹下go文件定义的结构体,我们现在来看看gosec原生的match方法是怎么写的来学习一下
AST树定位 在这之前要学习一下怎么去找到我们想要的节点,我们以一个赋值语句AssignStmt为例
1 err := db.QueryRow("SELECT * FROM users WHERE id=" + id).Scan(&u)
首先使用断言先拿到整个赋值表达式
1 2 3 4 if stmt, ok := n.(*ast.AssignStmt); ok { ast.Print(fset, stmt) } return true
接着可以看到里面各种节点的关系和名称
接着我们来定位这个赋值表达式左边的变量的名称,左右两边的表达式都是一个切片
1 2 3 4 5 if stmt, ok := n.(*ast.AssignStmt); ok { if lname, ok := stmt.Lhs[0 ].(*ast.Ident); ok { fmt.Print(lname.Name) } }
下面我们去找到赋值表达式右边的db变量以及它调用的QueryRow和scan方法以及它们的参数
1 2 3 4 5 if stmt, ok := n.(*ast.AssignStmt); ok { if l, ok := stmt.Rhs[0 ].(*ast.CallExpr).Fun.(*ast.SelectorExpr).X.(*ast.CallExpr).Fun.(*ast.SelectorExpr).X.(*ast.Ident); ok { fmt.Printf(l.Name) } }
其他的只要跟着ast.print出的东西对照着一层一层写进去就可以了
贴一个类型转换表
上下文CTX 上面提到在进行match()的时候不仅传入了订阅的节点,还传入了一个ctx,也就是上下文,那么它有什么用呢,他又是什么,首先回到我们的Visitor结构体,发现它其实就是这种结构体的第二个参数
1 2 3 4 5 6 7 8 9 type astVisitor struct { gosec *Analyzer context *Context issues []*issue.Issue stats *Metrics ignoreNosec bool showIgnored bool trackSuppressions bool }
它提供了规则分析所需的 全局视野 。如果没有 ctx ,你只有一个孤零零的 AST 节点 db ,你只知道它叫 “db”。 有了 ctx ,你可以查阅:
ctx.Info (types.Info) : 这一行代码里变量的类型是什么 ctx.Pkg (types.Package) : 当前是在哪个包里 ctx.Config : 用户有没有设置什么忽略规则 匹配检测具体逻辑 这里我们以一个sql注入的检测为例
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 func (s *sqlStrFormat) Match(n ast.Node, ctx *gosec.Context) (*issue.Issue, error ) { switch stmt := n.(type ) { case *ast.AssignStmt: for _, expr := range stmt.Rhs { if call, ok := expr.(*ast.CallExpr); ok { if sel, ok := call.Fun.(*ast.SelectorExpr); ok { if sqlCall, ok := sel.X.(*ast.CallExpr); ok && s.ContainsCallExpr(sqlCall, ctx) != nil { return s.checkQuery(sqlCall, ctx) } } if s.ContainsCallExpr(expr, ctx) != nil { return s.checkQuery(call, ctx) } } } case *ast.ExprStmt: if call, ok := stmt.X.(*ast.CallExpr); ok && s.ContainsCallExpr(call, ctx) != nil { return s.checkQuery(call, ctx) } } return nil , nil }
这是整块的match方法,它订阅了两个节点,一个是AssignStmt一个是ExprStmt,我们先从AssignStmt的处理开始阅读
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 func (s *sqlStrFormat) Match(n ast.Node, ctx *gosec.Context) (*issue.Issue, error ) { switch stmt := n.(type ) { case *ast.AssignStmt: for _, expr := range stmt.Rhs { if call, ok := expr.(*ast.CallExpr); ok { if sel, ok := call.Fun.(*ast.SelectorExpr); ok { if sqlCall, ok := sel.X.(*ast.CallExpr); ok && s.ContainsCallExpr(sqlCall, ctx) != nil { return s.checkQuery(sqlCall, ctx) } } if s.ContainsCallExpr(expr, ctx) != nil { return s.checkQuery(call, ctx) } } }
我们可以看到它是把我们的一整个CallExpr和传入checkQuery和ContainsCallExpr之中,我们看看这两个函数做了什么,总的来说它们的逻辑如下
s.ContainsCallExpr : 查黑名单。判断 db.Query 这个函数名是否在我们的监控列表里。 s.checkQuery : 如果在名单里,进一步深入检查参数 接下来讲解这两个部分 ContainsCallExpr包名函数名称检测 getCallInfo获取函数调用的具体信息 先是来到函数定义处
1 2 3 4 5 6 7 8 9 10 11 12 13 14 func (c CallList) ContainsCallExpr(n ast.Node, ctx *Context) *ast.CallExpr { selector, ident, err := GetCallInfo(n, ctx) if err != nil { return nil } if !c.Contains(selector, ident) && !c.ContainsPointer(selector, ident) { return nil } return n.(*ast.CallExpr) }
下面我们跟进GetCallInfo
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 func GetCallInfo (n ast.Node, ctx *Context) (string , string , error ) { if ctx.callCache != nil { if res, ok := ctx.callCache[n]; ok { return res.packageName, res.funcName, res.err } } packageName, funcName, err := getCallInfo(n, ctx) if ctx.callCache != nil { ctx.callCache[n] = callInfo{packageName, funcName, err} } return packageName, funcName, err }
下面我们来到真正干活的函数getCallInfo,代码比较长,一点点注释阅读
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 func getCallInfo (n ast.Node, ctx *Context) (string , string , error ) { switch node := n.(type ) { case *ast.CallExpr: switch fn := node.Fun.(type ) { case *ast.SelectorExpr: switch expr := fn.X.(type ) { case *ast.Ident: if expr.Obj != nil && expr.Obj.Kind == ast.Var { t := ctx.Info.TypeOf(expr) if t != nil { return t.String(), fn.Sel.Name, nil } return "undefined" , fn.Sel.Name, fmt.Errorf("missing type info for variable: %s" , expr.Name) } return expr.Name, fn.Sel.Name, nil case *ast.SelectorExpr: if expr.Sel != nil { t := ctx.Info.TypeOf(expr.Sel) if t != nil { return t.String(), fn.Sel.Name, nil } return "undefined" , fn.Sel.Name, fmt.Errorf("missing type info for selector" ) } case *ast.CallExpr: switch call := expr.Fun.(type ) { case *ast.Ident: if call.Name == "new" && len (expr.Args) > 0 { t := ctx.Info.TypeOf(expr.Args[0 ]) if t != nil { return t.String(), fn.Sel.Name, nil } return "undefined" , fn.Sel.Name, fmt.Errorf("missing type info for new()" ) } if call.Obj != nil { switch decl := call.Obj.Decl.(type ) { case *ast.FuncDecl: ret := decl.Type.Results if ret != nil && len (ret.List) > 0 { ret1 := ret.List[0 ] if ret1 != nil { t := ctx.Info.TypeOf(ret1.Type) if t != nil { return t.String(), fn.Sel.Name, nil } return "undefined" , fn.Sel.Name, fmt.Errorf("missing return type info" ) } } } } } } case *ast.Ident: return ctx.Pkg.Name(), fn.Name, nil } } return "" , "" , fmt.Errorf("unable to determine call info" ) }
Contains与ContainsPointer检测是否是漏洞规则函数 回到我们的大函数
1 2 3 4 5 6 7 8 9 10 11 12 13 14 func (c CallList) ContainsCallExpr(n ast.Node, ctx *Context) *ast.CallExpr { selector, ident, err := GetCallInfo(n, ctx) if err != nil { return nil } if !c.Contains(selector, ident) && !c.ContainsPointer(selector, ident) { return nil } return n.(*ast.CallExpr) }
可以知道这两个小检测函数都来自我们的CallList结构体,我们可以先看看这个结构体的结构
可以看到它是一个嵌套MAP
1 2 3 4 5 type set map [string ]bool type CallList map [string ]set
CallList结构体本身有一个ADD方法用来把规则塞进去,规则还是存在与rules/文件夹下各个文件夹中,在sql.go我们可以看见它的规则
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 var sqlCallIdents = map [string ]map [string ]int { "*database/sql.Conn" : { "ExecContext" : 1 , "QueryContext" : 1 , "QueryRowContext" : 1 , "PrepareContext" : 1 , }, "*database/sql.DB" : { "Exec" : 0 , "ExecContext" : 1 , "Query" : 0 , "QueryContext" : 1 , "QueryRow" : 0 , "QueryRowContext" : 1 , "Prepare" : 0 , "PrepareContext" : 1 , }, "*database/sql.Tx" : { "Exec" : 0 , "ExecContext" : 1 , "Query" : 0 , "QueryContext" : 1 , "QueryRow" : 0 , "QueryRowContext" : 1 , "Prepare" : 0 , "PrepareContext" : 1 , }, }
第一个自然就是标准库名,第二个是函数名,第三个就是容易被注入的参数位置
而我们可以发现CallList第三个是bool型,因为通用层并不关心在第几个参数,只关心这个Call有没有在其中,当我们想要去获得具体的参数位置的时候再拿
1 2 3 4 5 6 7 8 9 10 11 12 13 14 func findQueryArg (call *ast.CallExpr, ctx *gosec.Context) (ast.Expr, error ) { typeName, fnName, err := gosec.GetCallInfo(call, ctx) if err != nil { return nil , err } if methods, ok := sqlCallIdents[typeName]; ok { if i, ok := methods[fnName]; ok && i < len (call.Args) { return call.Args[i], nil } } return nil , fmt.Errorf("SQL argument index not found for %s.%s" , typeName, fnName) }
1 2 3 4 5 func (s *sqlStrConcat) checkQuery(call *ast.CallExpr, ctx *gosec.Context) (*issue.Issue, error ) { query, err := findQueryArg(call, ctx) if err != nil { return nil , err }
我们回到我们要看的两个小函数Contains
首先先看c.Contanis
1 2 3 4 5 6 7 func (c CallList) Contains(selector, ident string ) bool { if idents, ok := c[selector]; ok { _, found := idents[ident] return found } return false }
这就是个很简单的查表逻辑
再看c.ContainsPointer
1 2 3 4 5 6 7 8 9 10 func (c CallList) ContainsPointer(selector, indent string ) bool { if strings.HasPrefix(selector, "*" ) { if c.Contains(selector, indent) { return true } s := strings.TrimPrefix(selector, "*" ) return c.Contains(s, indent) } return false }
这是一个模糊匹配机制,具体处理以下情况
规则定义假设有一个规则,它注册时写的是 非指针类型 :
1 2 rule.Add("bytes.Buffer" , "WriteString" )
但在用户的实际代码里,大家通常是传 Buffer 的 指针 来用的
1 2 3 4 func process (buf *bytes.Buffer) { *bytes.Buffer buf.WriteString("hello" ) }
匹配过程
当 gosec 扫描到 buf.WriteString(“hello”) 时:
GetCallInfo : ctx.Info.TypeOf(buf) 会告诉 gosec,这个变量 buf 的类型是 *bytes.Buffer
Contains : 先拿 *bytes.Buffer 去 Map 里查。
Map 里只有 bytes.Buffer 匹配失败 。 ContainsPointer
发现 selector 是 *bytes.Buffer ,以 * 开头。 尝试去掉 * ,变成 bytes.Buffer 。 拿 bytes.Buffer 去 Map 里查。 匹配成功! 如果从这个包名和参数名的层面确认了,这是一个危险的地方,那么就到下一步参数的检测了
checkQuery传入参数检测 这个过程也比较复杂,有100多行的代码
首先就是上面提到的获得参数
1 2 3 4 5 6 7 8 9 10 11 func (s *sqlStrConcat) checkQuery(call *ast.CallExpr, ctx *gosec.Context) (*issue.Issue, error ) { query, err := findQueryArg(call, ctx) if err != nil { return nil , err } if be, ok := query.(*ast.BinaryExpr); ok { ... ...
第一种情况,判断是不是db.query("select"+q)这样的语句,对于这种情况,gosec是这样匹配的
先判断参数是不是一个字面量 如果不是字面量,拆解参数 判断开头字面量是否是sql模式 判断后续变量是否可控,即不是常量 具体代码如下
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 if be, ok := query.(*ast.BinaryExpr); ok { operands := gosec.GetBinaryExprOperands(be) if start, ok := operands[0 ].(*ast.BasicLit); ok { if str, e := gosec.GetString(start); e == nil && s.MatchPatterns(str) { for _, op := range operands[1 :] { if gosec.TryResolve(op, ctx) { continue } return ctx.NewIssue(be, s.ID(), s.What, s.Severity, s.Confidence), nil } } } return nil , nil }
第二种情况,为了解决db.query(q)这样的情况,它的匹配逻辑如下
匹配是不是一个字面量,如果是”ss”这样写好的直接pass 如果不是字面量,则不像上面的解析去看他是不是变量,而是把这个当做变量,然后去查阅他被声明的地方 这里使用FileSet去还原某个定义对象的位置信息
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 v, ok := ctx.Info.ObjectOf(ident).(*types.Var) if !ok { return nil , nil } isPkgLevel := ctx.Pkg != nil && v.Parent() == ctx.Pkg.Scope() var filesToSearch []*ast.Fileif isPkgLevel { filesToSearch = ctx.PkgFiles } else { callFile := gosec.ContainingFile(call, ctx) if callFile == nil { return nil , nil } filesToSearch = []*ast.File{callFile} } declRHS := []ast.Expr{} foundDecl := false var declFile *ast.Fileif ctx.FileSet != nil { if posFile := ctx.FileSet.File(v.Pos()); posFile != nil { targetName := posFile.Name() for _, f := range filesToSearch { if fileInfo := ctx.FileSet.File(f.Pos()); fileInfo != nil && fileInfo.Name() == targetName { declFile = f break } } } }
找到哪个声明变量的文件之后,我们就用ast.inspect去扫描这个文件,找到这个声明具体的表达式
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 if declFile != nil { ast.Inspect(declFile, func (n ast.Node) bool { switch d := n.(type ) { case *ast.ValueSpec: for _, name := range d.Names { if name.Pos() == v.Pos() && ctx.Info.ObjectOf(name) == v { declRHS = d.Values foundDecl = true return false } } case *ast.AssignStmt: if d.Tok == token.DEFINE { for _, lhs := range d.Lhs { if id, ok := lhs.(*ast.Ident); ok && id.Pos() == v.Pos() && ctx.Info.ObjectOf(id) == v { declRHS = d.Rhs foundDecl = true return false } } } } return true }) }
找到这个表达式并且取出它右边的值之后就可以做最后的判断了
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 if foundDecl { hasSQLPattern := false for _, val := range declRHS { if str, err := gosec.GetStringRecursive(val); err == nil && s.MatchPatterns(str) { hasSQLPattern = true break } } if inj := s.findInjectionInBranch(ctx, declRHS); inj != nil { return ctx.NewIssue(inj, s.ID(), s.What, s.Severity, s.Confidence), nil } if !hasSQLPattern { return nil , nil } } else { return nil , nil } for _, f := range filesToSearch { var found *ast.AssignStmt ast.Inspect(f, func (n ast.Node) bool { assign, ok := n.(*ast.AssignStmt) if !ok || len (assign.Lhs) != 1 || len (assign.Rhs) != 1 { return true } lIdent, ok := assign.Lhs[0 ].(*ast.Ident) if !ok || ctx.Info.ObjectOf(lIdent) != v { return true } var appended ast.Expr switch assign.Tok { case token.ADD_ASSIGN: appended = assign.Rhs[0 ] case token.ASSIGN: be, ok := assign.Rhs[0 ].(*ast.BinaryExpr) if !ok || be.Op != token.ADD { return true } left, ok := be.X.(*ast.Ident) if !ok || ctx.Info.ObjectOf(left) != v { return true } appended = be.Y default : return true } if !gosec.TryResolve(appended, ctx) { found = assign return false } return true }) if found != nil { return ctx.NewIssue(found, s.ID(), s.What, s.Severity, s.Confidence), nil } } return nil , nil
这就是最后的检查流程了,到这算完成了一整个sql注入的扫描了