// Package html wraps golang.org/x/net/html nodes with helpers for locating
// elements by tag, class, or id and for extracting their text.
package html

import (
	"strings"

	"golang.org/x/net/html"
)

// HtmlNode is a golang.org/x/net/html Node with additional query helpers.
// It has the same underlying type as html.Node, so *html.Node and *HtmlNode
// convert to each other directly.
type HtmlNode html.Node

// NodeComparer is a predicate used by Find and FindAll to select nodes.
type NodeComparer func(n *HtmlNode) bool

// GetAttribute returns the value of the first attribute of n named key, split
// on single space characters, and whether such an attribute exists. Because
// the split is on " " only, consecutive spaces yield empty strings in the
// result. A nil receiver or a missing attribute yields an empty slice and
// false.
func (n *HtmlNode) GetAttribute(key string) ([]string, bool) {
	if n == nil {
		return []string{}, false
	}
	if val, ok := n.rawAttr(key); ok {
		return strings.Split(val, " "), true
	}
	return []string{}, false
}

// rawAttr returns the unsplit value of the first attribute of n named key and
// whether it was found. n must be non-nil.
func (n *HtmlNode) rawAttr(key string) (string, bool) {
	for _, a := range n.Attr {
		if a.Key == key {
			return a.Val, true
		}
	}
	return "", false
}

// HasAttributeVal returns a NodeComparer matching element nodes whose attr
// attribute contains val as one of its whitespace-separated tokens (see
// CheckAttribute). The receiver is not used; the method form is kept for
// backward compatibility.
func (*HtmlNode) HasAttributeVal(attr string, val string) NodeComparer {
	return func(c *HtmlNode) bool {
		return c.CheckAttribute(attr, val)
	}
}

// CheckAttribute reports whether n is an element node whose attr attribute
// contains val as one of its tokens. Tokens are separated by runs of ASCII
// whitespace (space, tab, LF, FF, CR), the way HTML separates class names, so
// class="a\nb" matches both "a" and "b". It returns false for a nil node, a
// non-element node, or a missing attribute.
func (n *HtmlNode) CheckAttribute(attr string, val string) bool {
	if n == nil || n.Type != html.ElementNode {
		return false
	}
	raw, ok := n.rawAttr(attr)
	if !ok {
		return false
	}
	for _, tok := range strings.FieldsFunc(raw, isASCIIWhitespace) {
		if tok == val {
			return true
		}
	}
	return false
}

// isASCIIWhitespace reports whether r is one of the characters HTML treats as
// a token separator in attribute values.
func isASCIIWhitespace(r rune) bool {
	switch r {
	case ' ', '\t', '\n', '\f', '\r':
		return true
	}
	return false
}

// And returns a NodeComparer that matches only when both f and f2 match.
// f2 is not evaluated when f returns false.
func (f NodeComparer) And(f2 NodeComparer) NodeComparer {
	return func(n *HtmlNode) bool {
		return f(n) && f2(n)
	}
}

// Find returns the first node, in depth-first pre-order starting with n
// itself, for which f returns true. It returns nil if n is nil or nothing
// matches.
func (n *HtmlNode) Find(f NodeComparer) *HtmlNode {
	if n == nil {
		return nil
	}
	if f(n) {
		return n
	}
	for c := (*HtmlNode)(n.FirstChild); c != nil; c = (*HtmlNode)(c.NextSibling) {
		if found := c.Find(f); found != nil {
			return found
		}
	}
	return nil
}

// FindAll returns the nodes in the subtree rooted at n (n included) for which
// f returns true, in depth-first pre-order. Once a node matches, its
// descendants are not searched, so matches nested inside another match are
// not returned. The result is an empty, non-nil slice when nothing matches or
// n is nil.
func (n *HtmlNode) FindAll(f NodeComparer) []*HtmlNode {
	return n.appendMatches(f, []*HtmlNode{})
}

// appendMatches appends the FindAll matches of f under n to dst and returns
// the extended slice; sharing one slice avoids a temporary allocation per
// visited node.
func (n *HtmlNode) appendMatches(f NodeComparer, dst []*HtmlNode) []*HtmlNode {
	if n == nil {
		return dst
	}
	if f(n) {
		return append(dst, n)
	}
	for c := (*HtmlNode)(n.FirstChild); c != nil; c = (*HtmlNode)(c.NextSibling) {
		dst = c.appendMatches(f, dst)
	}
	return dst
}

// IsText returns a NodeComparer matching text nodes. A nil node never
// matches.
func IsText() NodeComparer {
	return func(n *HtmlNode) bool {
		return n != nil && n.Type == html.TextNode
	}
}

// IsTag returns a NodeComparer matching element nodes whose tag name
// (Node.Data) equals name exactly. A nil node never matches.
func IsTag(name string) NodeComparer {
	return func(n *HtmlNode) bool {
		return n != nil && n.Type == html.ElementNode && n.Data == name
	}
}

// HasClass returns a NodeComparer matching element nodes whose class
// attribute contains name as a token.
func HasClass(name string) NodeComparer {
	return func(n *HtmlNode) bool {
		return n.CheckAttribute("class", name)
	}
}

// HasID returns a NodeComparer matching element nodes whose id attribute
// contains id as a token.
func HasID(id string) NodeComparer {
	return func(n *HtmlNode) bool {
		return n.CheckAttribute("id", id)
	}
}

// GetElementById returns the first element under n (n included), in
// depth-first pre-order, whose id matches id, or nil if there is none or n is
// nil.
func (n *HtmlNode) GetElementById(id string) *HtmlNode {
	// Find already returns nil for a nil receiver.
	return n.Find(HasID(id))
}

// GetElementsByClass returns the elements under n (n included) whose class
// list contains class, following FindAll's rules: elements nested inside a
// matching element are not included. It returns nil when n is nil.
func (n *HtmlNode) GetElementsByClass(class string) []*HtmlNode {
	if n == nil {
		return nil
	}
	return n.FindAll(HasClass(class))
}

// Text returns the whitespace-trimmed contents of every text node under n
// (n included), skipping those that are empty after trimming, joined with
// single spaces. It returns "" for a nil node or when there is no such text.
func (n *HtmlNode) Text() string {
	if n == nil {
		return ""
	}
	var texts []string
	for _, t := range n.FindAll(IsText()) {
		if s := strings.TrimSpace(t.Data); s != "" {
			texts = append(texts, s)
		}
	}
	// strings.Join of an empty slice is "", so no special case is needed.
	return strings.Join(texts, " ")
}

// FirstText returns the whitespace-trimmed data of the first text node under
// n (n included) in depth-first pre-order, or "" when n is nil or has no text
// node. If that first text node is whitespace-only, the result is "" even if
// later text nodes have content.
func (n *HtmlNode) FirstText() string {
	text := n.Find(IsText())
	if text == nil {
		// Previously this dereferenced a nil node and panicked.
		return ""
	}
	return strings.TrimSpace(text.Data)
}