// Package html adds small lookup and text-extraction helpers on top of the
// Node type from golang.org/x/net/html.
package html

import (
	"strings"

	"golang.org/x/net/html"
)

// HtmlNode is an html.Node with query methods attached. Convert an existing
// *html.Node with (*HtmlNode)(node). Every method is safe to call on a nil
// receiver.
type HtmlNode html.Node

// GetAttribute returns the value of the first attribute on n whose key equals
// key, split on single space characters, and true. If n is nil or carries no
// such attribute it returns an empty slice and false.
//
// Because the split is on a single " ", consecutive spaces in the value
// produce empty strings in the result, and an empty value yields [""].
func (n *HtmlNode) GetAttribute(key string) ([]string, bool) {
	if n == nil {
		return []string{}, false
	}
	for _, a := range n.Attr {
		if a.Key == key {
			return strings.Split(a.Val, " "), true
		}
	}
	return []string{}, false
}

// CheckAttribute reports whether n is an element node whose attribute attr
// contains val as one of its space-separated values (see GetAttribute).
func (n *HtmlNode) CheckAttribute(attr string, val string) bool {
	if n == nil || n.Type != html.ElementNode {
		return false
	}
	vals, ok := n.GetAttribute(attr)
	if !ok {
		return false
	}
	for _, v := range vals {
		if v == val {
			return true
		}
	}
	return false
}

// Find walks the subtree rooted at n depth-first, testing n itself before its
// children, and returns the first node for which f returns true. It returns
// nil if n is nil or no node matches. f is never called with a nil node.
func (n *HtmlNode) Find(f func(n *HtmlNode) bool) *HtmlNode {
	if n == nil {
		return nil
	}
	if f(n) {
		return n
	}
	for c := (*HtmlNode)(n.FirstChild); c != nil; c = (*HtmlNode)(c.NextSibling) {
		if found := c.Find(f); found != nil {
			return found
		}
	}
	return nil
}

// FindAll walks the subtree rooted at n depth-first and collects the nodes for
// which f returns true, in the order they are visited. The result is always
// non-nil; it is empty when n is nil or nothing matches.
//
// The walk does not descend into a matching node: once f accepts a node, that
// node is recorded and its descendants are skipped, so only the outermost
// match on each branch is returned.
func (n *HtmlNode) FindAll(f func(n *HtmlNode) bool) []*HtmlNode {
	if n == nil {
		return []*HtmlNode{}
	}
	if f(n) {
		return []*HtmlNode{n}
	}
	matches := []*HtmlNode{}
	for c := (*HtmlNode)(n.FirstChild); c != nil; c = (*HtmlNode)(c.NextSibling) {
		matches = append(matches, c.FindAll(f)...)
	}
	return matches
}

// GetElementById returns the first element in the subtree rooted at n
// (n included) whose "id" attribute contains id, or nil if there is none.
func (n *HtmlNode) GetElementById(id string) *HtmlNode {
	// Find already returns nil for a nil receiver, so no extra guard is needed.
	return n.Find(func(c *HtmlNode) bool {
		return c.CheckAttribute("id", id)
	})
}

// GetElementsByClass returns the elements in the subtree rooted at n
// (n included) whose "class" attribute contains class. Elements nested inside
// an element that already matched are not reported (see FindAll). It returns
// nil when n is nil and an empty, non-nil slice when nothing matches.
func (n *HtmlNode) GetElementsByClass(class string) []*HtmlNode {
	if n == nil {
		return nil
	}
	return n.FindAll(func(c *HtmlNode) bool {
		return c.CheckAttribute("class", class)
	})
}

// Text returns the data of every text node in the subtree rooted at n, each
// trimmed of surrounding whitespace, with empty results dropped and the rest
// joined by a single space. It returns "" when n is nil or no text remains.
func (n *HtmlNode) Text() string {
	if n == nil {
		return ""
	}
	textNodes := n.FindAll(func(c *HtmlNode) bool {
		return c.Type == html.TextNode
	})
	parts := make([]string, 0, len(textNodes))
	for _, tn := range textNodes {
		if s := strings.TrimSpace(tn.Data); s != "" {
			parts = append(parts, s)
		}
	}
	// strings.Join yields "" for an empty slice, so no special case is needed.
	return strings.Join(parts, " ")
}

// FirstText returns the whitespace-trimmed data of the first text node found
// by a depth-first walk of the subtree rooted at n. The result may be "" if
// that first text node holds only whitespace. It also returns "" when n is nil
// or the subtree contains no text node.
func (n *HtmlNode) FirstText() string {
	first := n.Find(func(c *HtmlNode) bool {
		return c.Type == html.TextNode
	})
	// Find returns nil when nothing matched; dereferencing it would panic.
	if first == nil {
		return ""
	}
	return strings.TrimSpace(first.Data)
}