goutil/html/html.go
2020-11-08 19:28:25 +01:00

135 lines
2.7 KiB
Go

package html
import (
"golang.org/x/net/html"
"strings"
)
type (
	// HtmlNode is a defined type whose underlying type is html.Node from
	// golang.org/x/net/html; it exists to carry the query helpers of this
	// package. A *html.Node converts to *HtmlNode with (*HtmlNode)(p).
	// Methods declared on html.Node are not promoted (this is a defined type,
	// not an embedding); convert back with (*html.Node)(n) to use them.
	HtmlNode html.Node

	// NodeComparer is a predicate over nodes. Find and FindAll use it to
	// select matches, and And combines two of them.
	NodeComparer func(n *HtmlNode) bool
)
// GetAttribute looks up the attribute named key on n. It returns the value
// split into whitespace-separated tokens, and whether the attribute exists.
//
// Tokens are separated by runs of ASCII whitespace (space, tab, LF, FF, CR).
// This matches how HTML defines space-separated token lists such as "class",
// so multi-line or tab-separated class attributes tokenize correctly. Empty
// tokens from repeated, leading or trailing whitespace are dropped.
//
// Some attributes are present but contain no tokens: an empty value, or one
// made only of whitespace. For those the result is a single empty string.
// Callers that index the first element after ok == true keep working, and
// CheckAttribute(key, "") still reports the attribute as present.
//
// A nil receiver or a missing attribute yields an empty slice and false.
func (n *HtmlNode) GetAttribute(key string) ([]string, bool) {
	if n == nil {
		return []string{}, false
	}
	for _, attr := range n.Attr {
		if attr.Key != key {
			continue
		}
		tokens := strings.FieldsFunc(attr.Val, isHTMLSpace)
		if len(tokens) == 0 {
			// Keep the "present => at least one element" invariant.
			return []string{""}, true
		}
		return tokens, true
	}
	return []string{}, false
}

// isHTMLSpace reports whether r is one of the ASCII whitespace characters
// HTML uses to separate tokens in attribute values.
func isHTMLSpace(r rune) bool {
	switch r {
	case ' ', '\t', '\n', '\f', '\r':
		return true
	}
	return false
}
// HasAttributeVal returns a NodeComparer that is true for element nodes whose
// attribute attr contains val as a token (see CheckAttribute).
//
// The receiver is never consulted. The returned comparer tests whatever node
// it is later called with, so this may be invoked on a nil *HtmlNode.
func (*HtmlNode) HasAttributeVal(attr string, val string) NodeComparer {
	return NodeComparer(func(candidate *HtmlNode) bool {
		return candidate.CheckAttribute(attr, val)
	})
}
// CheckAttribute reports whether n is an element node whose attribute attr,
// as tokenized by GetAttribute, contains val as an exact token.
// It returns false for a nil node, a non-element node, or a missing attribute.
func (n *HtmlNode) CheckAttribute(attr string, val string) bool {
	if n == nil || n.Type != html.ElementNode {
		return false
	}
	tokens, found := n.GetAttribute(attr)
	if !found {
		return false
	}
	for i := range tokens {
		if tokens[i] == val {
			return true
		}
	}
	return false
}
// And combines f and f2 into a comparer that matches a node only when both
// match it. f2 is evaluated only if f has already returned true.
func (f NodeComparer) And(f2 NodeComparer) NodeComparer {
	both := func(node *HtmlNode) bool {
		if !f(node) {
			return false
		}
		return f2(node)
	}
	return both
}
// Find searches the subtree rooted at n (n itself included) depth-first in
// pre-order. It returns the first node for which f reports true, or nil if
// nothing matches or n is nil.
func (n *HtmlNode) Find(f NodeComparer) *HtmlNode {
	if n == nil {
		return nil
	}
	if f(n) {
		return n
	}
	child := (*HtmlNode)(n.FirstChild)
	for child != nil {
		if hit := child.Find(f); hit != nil {
			return hit
		}
		child = (*HtmlNode)(child.NextSibling)
	}
	return nil
}
// FindAll returns every node in the subtree rooted at n (n included) for
// which f reports true, in depth-first pre-order (document order).
//
// Matching nodes are still descended into. Nested matches are therefore all
// returned, for example an element with class "item" inside another element
// with class "item". This is in line with DOM methods such as
// getElementsByClassName. The result is a non-nil empty slice when nothing
// matches or n is nil.
func (n *HtmlNode) FindAll(f NodeComparer) []*HtmlNode {
	return appendMatches(n, f, []*HtmlNode{})
}

// appendMatches appends to dst, in pre-order, every node of the subtree rooted
// at n that satisfies f, and returns the extended slice. Threading one slice
// through the recursion avoids building and concatenating a separate result
// slice at every level.
func appendMatches(n *HtmlNode, f NodeComparer, dst []*HtmlNode) []*HtmlNode {
	if n == nil {
		return dst
	}
	if f(n) {
		dst = append(dst, n)
	}
	for c := (*HtmlNode)(n.FirstChild); c != nil; c = (*HtmlNode)(c.NextSibling) {
		dst = appendMatches(c, f, dst)
	}
	return dst
}
// IsText returns a NodeComparer that matches text nodes.
// The comparer dereferences its argument without a nil check, so it must not
// be called with nil (Find and FindAll never do).
func IsText() NodeComparer {
	var isText NodeComparer = func(node *HtmlNode) bool {
		return node.Type == html.TextNode
	}
	return isText
}
// IsTag returns a NodeComparer that matches element nodes whose tag name
// (node.Data) equals name exactly. The comparison is case-sensitive.
// NOTE(review): x/net/html is expected to lower-case HTML tag names while
// parsing, so name should normally be lower case — confirm for SVG/MathML.
func IsTag(name string) NodeComparer {
	return func(node *HtmlNode) bool {
		if node.Type != html.ElementNode {
			return false
		}
		return node.Data == name
	}
}
// HasClass returns a NodeComparer that matches element nodes whose "class"
// attribute contains name as a token.
func HasClass(name string) NodeComparer {
	match := func(node *HtmlNode) bool {
		return node.CheckAttribute("class", name)
	}
	return match
}
// HasID returns a NodeComparer that matches element nodes whose "id"
// attribute contains id as a token.
func HasID(id string) NodeComparer {
	match := func(node *HtmlNode) bool {
		return node.CheckAttribute("id", id)
	}
	return match
}
// GetElementById returns the first node, in pre-order over n's subtree
// (n included), whose "id" attribute contains the token id.
// It returns nil when nothing matches or when n is nil.
func (n *HtmlNode) GetElementById(id string) *HtmlNode {
	if n == nil {
		return nil
	}
	return n.Find(HasID(id))
}
// GetElementsByClass returns the nodes of n's subtree (n included) that
// FindAll selects with HasClass(class).
// A nil receiver yields a nil slice, unlike FindAll's empty slice.
func (n *HtmlNode) GetElementsByClass(class string) []*HtmlNode {
	if n == nil {
		return nil
	}
	return n.FindAll(HasClass(class))
}
// Text returns the text content of n's subtree, built as follows:
//   - every text node found by FindAll (n included) is trimmed of surrounding
//     whitespace;
//   - empty results are dropped;
//   - the rest are joined with single spaces in FindAll's pre-order.
//
// A nil node yields "".
func (n *HtmlNode) Text() string {
	if n == nil {
		return ""
	}
	textNodes := n.FindAll(func(node *HtmlNode) bool { return node.Type == html.TextNode })
	var parts []string
	for _, textNode := range textNodes {
		if t := strings.TrimSpace(textNode.Data); t != "" {
			parts = append(parts, t)
		}
	}
	// strings.Join returns "" for an empty slice, so no special case is needed.
	return strings.Join(parts, " ")
}
// FirstText returns the data of the first text node found by a pre-order
// search of n's subtree (n included), trimmed of surrounding whitespace.
//
// It returns "" when n is nil or the subtree contains no text node.
// Whitespace-only text nodes count as matches, so the result can also be ""
// when such a node precedes the first non-blank text.
func (n *HtmlNode) FirstText() string {
	text := n.Find(func(node *HtmlNode) bool { return node.Type == html.TextNode })
	// Find returns nil for a nil receiver or when nothing matches;
	// dereferencing that result would panic.
	if text == nil {
		return ""
	}
	return strings.TrimSpace(text.Data)
}