diff --git a/html/html.go b/html/html.go
index deb592a..1a66a12 100644
--- a/html/html.go
+++ b/html/html.go
@@ -7,6 +7,8 @@ import (
type HtmlNode html.Node
+type NodeComparer func(n *HtmlNode) bool
+
func (n *HtmlNode) GetAttribute(key string) ([]string, bool) {
if n != nil {
for _, attr := range n.Attr {
@@ -18,6 +20,12 @@ func (n *HtmlNode) GetAttribute(key string) ([]string, bool) {
return []string{}, false
}
+func (n *HtmlNode) HasAttributeVal(attr string, val string) NodeComparer {
+ return func(n *HtmlNode) bool {
+ return n.CheckAttribute(attr, val)
+ }
+}
+
func (n *HtmlNode) CheckAttribute(attr string, val string) bool {
if n != nil && n.Type == html.ElementNode {
attrvals, ok := n.GetAttribute(attr)
@@ -32,7 +40,13 @@ func (n *HtmlNode) CheckAttribute(attr string, val string) bool {
return false
}
-func (n *HtmlNode) Find(f func(n *HtmlNode) bool) *HtmlNode {
+func (f NodeComparer) And(f2 NodeComparer) NodeComparer {
+ return func(n *HtmlNode) bool {
+ return f(n) && f2(n)
+ }
+}
+
+func (n *HtmlNode) Find(f NodeComparer) *HtmlNode {
if n != nil {
if f(n) {
return n
@@ -47,7 +61,7 @@ func (n *HtmlNode) Find(f func(n *HtmlNode) bool) *HtmlNode {
return nil
}
-func (n *HtmlNode) FindAll(f func(n *HtmlNode) bool) []*HtmlNode {
+func (n *HtmlNode) FindAll(f NodeComparer) []*HtmlNode {
all := []*HtmlNode{}
if n != nil {
if f(n) {
@@ -63,9 +77,25 @@ func (n *HtmlNode) FindAll(f func(n *HtmlNode) bool) []*HtmlNode {
return all
}
+func IsText() NodeComparer {
+ return func(n *HtmlNode) bool { return n.Type == html.TextNode }
+}
+
+func IsTag(name string) NodeComparer {
+ return func(n *HtmlNode) bool { return n.Type == html.ElementNode && n.Data == name }
+}
+
+func HasClass(name string) NodeComparer {
+ return func(n *HtmlNode) bool { return n.CheckAttribute("class", name) }
+}
+
+func HasID(id string) NodeComparer {
+ return func(n *HtmlNode) bool { return n.CheckAttribute("id", id) }
+}
+
func (n *HtmlNode) GetElementById(id string) *HtmlNode {
if n != nil {
- return n.Find(func(n *HtmlNode) bool { return n.CheckAttribute("id", id) })
+ return n.Find(HasID(id))
} else {
return nil
}
@@ -73,7 +103,7 @@ func (n *HtmlNode) GetElementById(id string) *HtmlNode {
func (n *HtmlNode) GetElementsByClass(class string) []*HtmlNode {
if n != nil {
- return n.FindAll(func(n *HtmlNode) bool { return n.CheckAttribute("class", class) })
+ return n.FindAll(HasClass(class))
} else {
return nil
}