/*
 * IP packet filter
 */
#include "u.h"
#include "../port/lib.h"
#include "mem.h"
#include "dat.h"
#include "fns.h"
#include "../port/error.h"

#include "ip.h"
#include "ipv6.h"

typedef struct Ipmuxrock  Ipmuxrock;
typedef struct Ipmux      Ipmux;

/*
 *  kinds of field a filter node can test.  ipmuxcmp orders
 *  nodes by this value first, lower values coming first.
 */
enum
{
	Tver,		/* "ver" */
	Tproto,		/* "proto" */
	Tdata,		/* "data[n:m]", offset counted from the end of the ip header */
	Tiph,		/* "iph[n:m]", offset counted from the start of the packet */
	Tdst,		/* "dst" */
	Tsrc,		/* "src" */
	Tifc,		/* "ifc" */
};

/*
 *  a node in the decision tree
 */
struct Ipmux
{
	Ipmux	*yes;
	Ipmux	*no;
	uchar	type;		/* type of field(Txxxx) */
	uchar	len;		/* length in bytes of item to compare */
	uchar	n;		/* number of items val points to */
	int	off;		/* offset of comparison */
	uchar	*val;
	uchar	*mask;
	uchar	*e;		/* val+n*len*/
	int	ref;		/* so we can garbage collect */
	Conv	*conv;
};

/*
 *  someplace to hold per conversation data
 */
struct Ipmuxrock
{
	Ipmux	*chain;
};

static int	ipmuxsprint(Ipmux*, int, char*, int);
static void	ipmuxkick(void *x);
static void	ipmuxfree(Ipmux *f);

/*
 *  return a pointer to the first character of p that is
 *  neither a space nor a tab
 */
static char*
skipwhite(char *p)
{
	while(*p == ' ' || *p == '\t')
		p++;
	return p;
}

/*
 *  split p at the first occurrence of c: the separator is
 *  overwritten with a NUL and a pointer to the text after it
 *  (leading blanks skipped) is returned.  returns nil if c is
 *  not present or nothing but blanks follows it.  note that
 *  the separator is overwritten even when nil is returned for
 *  the latter reason.
 */
static char*
follows(char *p, char c)
{
	char *f;

	f = strchr(p, c);
	if(f == nil)
		return nil;
	*f++ = 0;
	f = skipwhite(f);
	if(*f == 0)
		return nil;
	return f;
}

/*
 *  parse the operand (left hand side) of a relation and
 *  allocate a node describing which bytes it looks at.
 *  val and mask are left nil for the caller to fill in.
 *  the caller's pointer *pp is only read, never advanced.
 *  returns nil if the operand is not recognized or a
 *  data[]/iph[] range is malformed.
 */
static Ipmux*
parseop(char **pp)
{
	char *p = *pp;
	int type, off, end, len;
	Ipmux *f;

	p = skipwhite(p);
	if(strncmp(p, "ver", 3) == 0){
		type = Tver;
		off = 0;
		len = 1;
		p += 3;
	}
	else if(strncmp(p, "dst", 3) == 0){
		type = Tdst;
		off = offsetof(Ip6hdr, dst[0]);
		len = IPaddrlen;
		p += 3;
	}
	else if(strncmp(p, "src", 3) == 0){
		type = Tsrc;
		off = offsetof(Ip6hdr, src[0]);
		len = IPaddrlen;
		p += 3;
	}
	else if(strncmp(p, "ifc", 3) == 0){
		type = Tifc;
		off = -IPaddrlen;
		len = IPaddrlen;
		p += 3;
	}
	else if(strncmp(p, "proto", 5) == 0){
		type = Tproto;
		off = offsetof(Ip6hdr, proto);
		len = 1;
		p += 5;
	}
	else if(strncmp(p, "data", 4) == 0 || strncmp(p, "iph", 3) == 0){
		if(strncmp(p, "data", 4) == 0) {
			type = Tdata;
			p += 4;
		}
		else {
			type = Tiph;
			p += 3;
		}
		p = skipwhite(p);
		if(*p != '[')
			return nil;
		p++;
		off = strtoul(p, &p, 0);
		if(off < 0)
			return nil;
		p = skipwhite(p);
		if(*p != ':')
			end = off;	/* "[n]" is a one byte range */
		else {
			p++;
			p = skipwhite(p);
			end = strtoul(p, &p, 0);
			if(end < off)
				return nil;
			p = skipwhite(p);
		}
		if(*p != ']')
			return nil;
		p++;
		/*
		 *  the length is kept in a uchar; refuse ranges that
		 *  would be silently truncated when stored.
		 */
		if(end - off >= 0xFF)
			return nil;
		len = end - off + 1;
	}
	else
		return nil;

	f = smalloc(sizeof(*f));
	f->type = type;
	f->len = len;
	f->off = off;
	f->val = nil;
	f->mask = nil;
	f->n = 1;
	f->ref = 1;

	return f;
}

/*
 *  value of a single hex digit; any other character counts as 0
 */
static int
htoi(char x)
{
	if(x >= '0' && x <= '9')
		x -= '0';
	else if(x >= 'a' && x <= 'f')
		x -= 'a' - 10;
	else if(x >= 'A' && x <= 'F')
		x -= 'A' - 10;
	else
		x = 0;
	return x;
}

/*
 *  value of the two hex digits at p[0] and p[1]
 */
static int
hextoi(char *p)
{
	return (htoi(p[0])<<4) | htoi(p[1]);
}

/*
 *  convert the hex string p into at most len bytes at v.
 *  conversion stops at the end of the string; bytes of v
 *  beyond that are left untouched.  a trailing odd digit is
 *  stored as the high nibble of the last byte.
 */
static void
parseval(uchar *v, char *p, int len)
{
	while(*p != 0 && len-- > 0){
		*v++ = hextoi(p);
		/* odd number of digits: don't step over the terminating NUL */
		if(p[1] == 0)
			break;
		p += 2;
	}
}

/*
 *  parse one relation of the form  operand=value[|value...][&mask]
 *  into a freshly allocated node.  the string p is cut up in
 *  place.  returns nil on a syntax error.
 */
static Ipmux*
parsemux(char *p)
{
	int n;
	Ipmux *f;
	char *val;
	char *mask;
	char *vals[20];
	uchar *v;

	/* parse operand */
	f = parseop(&p);
	if(f == nil)
		return nil;

	/* find value */
	val = follows(p, '=');
	if(val == nil)
		goto parseerror;

	/*
	 *  parse mask.  it trails the value list (value&mask), so
	 *  it has to be looked for in val; p was already terminated
	 *  at the '=' above.
	 */
	mask = follows(val, '&');
	if(mask != nil){
		switch(f->type){
		case Tsrc:
		case Tdst:
		case Tifc:
			f->mask = smalloc(f->len);
			parseipmask(f->mask, mask, 0);
			break;
		case Tdata:
		case Tiph:
			f->mask = smalloc(f->len);
			parseval(f->mask, mask, f->len);
			break;
		default:
			goto parseerror;
		}
	} else if(f->type == Tver){
		/* only the high nibble of the first byte is compared */
		f->mask = smalloc(f->len);
		f->mask[0] = 0xF0;
	}

	/* parse vals */
	f->n = getfields(val, vals, nelem(vals), 1, "|");
	if(f->n == 0)
		goto parseerror;
	f->val = smalloc(f->n*f->len);
	v = f->val;
	for(n = 0; n < f->n; n++){
		switch(f->type){
		case Tver:
			if(f->n != 1)
				goto parseerror;
			if(strcmp(vals[n], "6") == 0)
				*v = IP_VER6;
			else if(strcmp(vals[n], "4") == 0)
				*v = IP_VER4;
			else
				goto parseerror;
			break;
		case Tsrc:
		case Tdst:
		case Tifc:
			if(parseip(v, vals[n]) == -1)
				goto parseerror;
			break;
		case Tproto:
		case Tdata:
		case Tiph:
			parseval(v, vals[n], f->len);
			break;
		}
		v += f->len;
	}
	f->e = f->val + f->n*f->len;
	return f;

parseerror:
	ipmuxfree(f);
	return nil;
}

/*
 *  Compare relative ordering of two ipmuxs.  This doesn't compare the
 *  values, just the fields being looked at.
 *
 *  returns:	<0 if a is a more specific match
 *		 0 if a and b are matching on the same fields
 *		>0 if b is a more specific match
 */
static int
ipmuxcmp(Ipmux *a, Ipmux *b)
{
	int n;

	/* compare types, lesser ones are more important */
	n = a->type - b->type;
	if(n != 0)
		return n;

	/* compare offsets, call earlier ones more specific */
	n = a->off - b->off;
	if(n != 0)
		return n;

	/* compare match lengths, longer ones are more specific */
	n = b->len - a->len;
	if(n != 0)
		return n;

	/*
	 *  if we get here we have two entries matching
	 *  the same bytes of the record.  Now check
	 *  the mask for equality.  Longer masks are
	 *  more specific.
	 */
	if(a->mask != nil && b->mask == nil)
		return -1;
	if(a->mask == nil && b->mask != nil)
		return 1;
	if(a->mask != nil && b->mask != nil){
		n = memcmp(b->mask, a->mask, a->len);
		if(n != 0)
			return n;
	}
	return 0;
}

/*
 *  Compare the values of two ipmuxs.  We're assuming that ipmuxcmp
 *  returned 0 comparing them.  Returns 0 only if both hold the same
 *  number of value bytes with identical contents.
 */
static int
ipmuxvalcmp(Ipmux *a, Ipmux *b)
{
	int n;

	n = b->len*b->n - a->len*a->n;
	if(n != 0)
		return n;
	return memcmp(a->val, b->val, a->len*a->n);
}

/*
 *  add onto an existing ipmux chain in the canonical comparison
 *  order.  the chain is linked through the yes pointers; f is
 *  inserted in front of the first node it sorts before.
 */
static void
ipmuxchain(Ipmux **l, Ipmux *f)
{
	for(; *l; l = &(*l)->yes)
		if(ipmuxcmp(f, *l) < 0)
			break;
	f->yes = *l;
	*l = f;
}

/*
 *  copy a tree.  every node and its val and mask buffers are
 *  duplicated; the remaining fields (including ref and conv)
 *  are copied as is.
 */
static Ipmux*
ipmuxcopy(Ipmux *f)
{
	Ipmux *nf;

	if(f == nil)
		return nil;
	nf = smalloc(sizeof *nf);
	*nf = *f;
	nf->no = ipmuxcopy(f->no);
	nf->yes = ipmuxcopy(f->yes);
	if(f->mask != nil){
		nf->mask = smalloc(f->len);
		memmove(nf->mask, f->mask, f->len);
	}
	nf->val = smalloc(f->n*f->len);
	nf->e = nf->val + f->len*f->n;
	memmove(nf->val, f->val, f->n*f->len);

	return nf;
}

/*
 *  free a single node and its buffers; the children are
 *  not touched.
 */
static void
ipmuxfree(Ipmux *f)
{
	if(f == nil)
		return;
	free(f->val);
	free(f->mask);
	free(f);
}

/*
 *  free a node and everything reachable from it.  both
 *  subtrees are walked recursively so that nodes more than
 *  one level down are released as well.
 */
static void
ipmuxtreefree(Ipmux *f)
{
	if(f == nil)
		return;
	ipmuxtreefree(f->no);
	ipmuxtreefree(f->yes);
	ipmuxfree(f);
}

/*
 *  merge two trees.  the result reuses the nodes of a and b;
 *  neither argument may be used separately afterwards.
 */
static Ipmux*
ipmuxmerge(Ipmux *a, Ipmux *b)
{
	int n;
	Ipmux *f;

	if(a == nil)
		return b;
	if(b == nil)
		return a;
	n = ipmuxcmp(a, b);
	if(n < 0){
		/* a tests an earlier field: b has to go down both of a's branches */
		f = ipmuxcopy(b);
		a->yes = ipmuxmerge(a->yes, b);
		a->no = ipmuxmerge(a->no, f);
		return a;
	}
	if(n > 0){
		/* b tests an earlier field: a has to go down both of b's branches */
		f = ipmuxcopy(a);
		b->yes = ipmuxmerge(b->yes, a);
		b->no = ipmuxmerge(b->no, f);
		return b;
	}
	if(ipmuxvalcmp(a, b) == 0){
		/* identical test: share the node and count the extra user */
		a->yes = ipmuxmerge(a->yes, b->yes);
		a->no = ipmuxmerge(a->no, b->no);
		a->ref++;
		ipmuxfree(b);
		return a;
	}
	/* same field, different values: b is an alternative on a's no side */
	a->no = ipmuxmerge(a->no, b);
	return a;
}

/*
 *  remove a chain from a demux tree.  This is like merging except that
 *  we remove instead of insert.  Returns 0 if f was removed, negative
 *  if some part of it could not be found.
 */
static int
ipmuxremove(Ipmux **l, Ipmux *f)
{
	int n, rv;
	Ipmux *ft;

	if(f == nil)
		return 0;		/* we've removed it all */
	if(*l == nil)
		return -1;

	ft = *l;
	n = ipmuxcmp(ft, f);
	if(n < 0){
		/* *l is matching an earlier field, descend both paths */
		rv = ipmuxremove(&ft->yes, f);
		rv += ipmuxremove(&ft->no, f);
		return rv;
	}
	if(n > 0){
		/* f represents an earlier field than *l, this should be impossible */
		return -1;
	}

	/* if we get here f and *l are comparing the same fields */
	if(ipmuxvalcmp(ft, f) != 0){
		/* different values mean mutually exclusive */
		return ipmuxremove(&ft->no, f);
	}

	ipmuxremove(&ft->no, f->no);

	/* we found a match */
	if(--(ft->ref) == 0){
		/*
		 *  a dead node implies the whole yes side is also dead.
		 *  since our chain is constrained to be on that side,
		 *  we're done.
		 */
		ipmuxtreefree(ft->yes);
		*l = ft->no;
		ipmuxfree(ft);
		return 0;
	}

	/*
	 *  free the rest of the chain.  it is constrained to match the
	 *  yes side.
	 */
	return ipmuxremove(&ft->yes, f->yes);
}

/*
 *  convert to ipv4 filter.  the tree is rewritten in place:
 *  proto, dst and src offsets are changed to their Ip4hdr
 *  positions and 16 byte dst/src values are reduced to the
 *  4 byte form, dropping values that are not v4 addresses.
 *  a dst/src node left with no values is freed together with
 *  its subtrees and nil is returned in its place.
 */
static Ipmux*
ipmuxconv4(Ipmux *f)
{
	int i, n;

	if(f == nil)
		return nil;

	switch(f->type){
	case Tproto:
		f->off = offsetof(Ip4hdr, proto);
		break;
	case Tdst:
	case Tsrc:
		if(f->type == Tdst)
			f->off = offsetof(Ip4hdr, dst[0]);
		else
			f->off = offsetof(Ip4hdr, src[0]);
		if(f->len != IPaddrlen)
			break;
		/* compact the v4 addresses to the front of val */
		n = 0;
		for(i = 0; i < f->n; i++){
			if(isv4(f->val + i*IPaddrlen)){
				memmove(f->val + n*IPv4addrlen, f->val + i*IPaddrlen + IPv4off, IPv4addrlen);
				n++;
			}
		}
		if(n == 0){
			ipmuxtreefree(f);
			return nil;
		}
		f->n = n;
		f->len = IPv4addrlen;
		if(f->mask != nil)
			memmove(f->mask, f->mask+IPv4off, IPv4addrlen);
	}
	f->e = f->val + f->n*f->len;

	f->yes = ipmuxconv4(f->yes);
	f->no = ipmuxconv4(f->no);

	return f;
}

/*
 *  connection request is a semi separated list of filters
 *  e.g. ver=4;proto=17;data[0:4]=11aa22bb;ifc=135.104.9.2&255.255.255.0
 *
 *  there's no protection against overlapping specs.
 *
 *  argv[1] holds the filter and is cut up in place.  returns
 *  nil on success or Ebadarg if the filter cannot be parsed.
 */
static char*
ipmuxconnect(Conv *c, char **argv, int argc)
{
	int i, n;
	char *field[10];
	Ipmux *mux, *chain;
	Ipmuxrock *r;
	Fs *f;

	f = c->p->f;

	if(argc != 2)
		return Ebadarg;

	n = getfields(argv[1], field, nelem(field), 1, ";");
	if(n <= 0)
		return Ebadarg;

	chain = nil;
	for(i = 0; i < n; i++){
		mux = parsemux(field[i]);
		if(mux == nil){
			ipmuxtreefree(chain);
			return Ebadarg;
		}
		ipmuxchain(&chain, mux);
	}
	if(chain == nil)
		return Ebadarg;

	/*
	 *  the conversation belongs on the last node of the chain,
	 *  reached only when every relation matched.  ipmuxchain
	 *  sorts the nodes, so that is not necessarily the relation
	 *  that was parsed last.
	 */
	for(mux = chain; mux->yes != nil; mux = mux->yes)
		;
	mux->conv = c;

	if(chain->type != Tver) {
		/*
		 *  no version given: match either one.  the yes side
		 *  keeps the chain as parsed, the no side gets a copy
		 *  that is converted to ipv4 below.
		 */
		char ver6[] = "ver=6";
		mux = parsemux(ver6);
		mux->yes = chain;
		mux->no = ipmuxcopy(chain);
		chain = mux;
	}
	if(*chain->val == IP_VER4)
		chain->yes = ipmuxconv4(chain->yes);
	else
		chain->no = ipmuxconv4(chain->no);

	/* save a copy of the chain so we can later remove it */
	mux = ipmuxcopy(chain);
	r = (Ipmuxrock*)(c->ptcl);
	r->chain = chain;

	/* add the chain to the protocol demultiplexor tree */
	wlock(f);
	f->ipmux->priv = ipmuxmerge(f->ipmux->priv, mux);
	wunlock(f);

	Fsconnected(c, nil);
	return nil;
}

/*
 *  print the conversation's own filter chain into state;
 *  returns the count reported by ipmuxsprint
 */
static int
ipmuxstate(Conv *c, char *state, int n)
{
	Ipmuxrock *r;

	r = (Ipmuxrock*)(c->ptcl);
	return ipmuxsprint(r->chain, 0, state, n);
}

/*
 *  set up the queues of a new conversation and start it
 *  out with no filter
 */
static void
ipmuxcreate(Conv *c)
{
	Ipmuxrock *r;

	c->rq = qopen(64*1024, Qmsg, 0, c);
	c->wq = qopen(64*1024, Qkick, ipmuxkick, c);
	r = (Ipmuxrock*)(c->ptcl);
	r->chain = nil;
}

/*
 *  announce is not supported; always returns an error string
 */
static char*
ipmuxannounce(Conv*, char**, int)
{
	return "ipmux does not support announce";
}

/*
 *  close the conversation's queues, clear its addresses and
 *  ports, take its filter out of the demultiplexor tree and
 *  free the saved chain
 */
static void
ipmuxclose(Conv *c)
{
	Ipmuxrock *r;
	Fs *f = c->p->f;

	r = (Ipmuxrock*)(c->ptcl);

	qclose(c->rq);
	qclose(c->wq);
	qclose(c->eq);
	ipmove(c->laddr, IPnoaddr);
	ipmove(c->raddr, IPnoaddr);
	c->lport = 0;
	c->rport = 0;

	wlock(f);
	ipmuxremove(&(c->p->priv), r->chain);
	wunlock(f);
	ipmuxtreefree(r->chain);
	r->chain = nil;
}

/*
 *  takes a fully formed ip packet and just passes it down
 *  the stack
 */
static void
ipmuxkick(void *x)
{
	Conv *c = x;
	Block *bp;

	bp = qget(c->wq);
	if(bp != nil) {
		Ip4hdr *ih4 = (Ip4hdr*)(bp->rp);

		/* anything whose version nibble isn't 6 goes out the v4 path */
		if((ih4->vihl & 0xF0) != IP_VER6)
			ipoput4(c->p->f, bp, 0, ih4->ttl, ih4->tos, nil);
		else
			ipoput6(c->p->f, bp, 0, ((Ip6hdr*)ih4)->ttl, 0, nil);
	}
}

/*
 *  compare n bytes of v, and'ed with the mask m, against c.
 *  with a nil mask the bytes are compared directly.
 *  returns 0 if they match, 1 if not.
 */
static int
maskmemcmp(uchar *m, uchar *v, uchar *c, int n)
{
	int i;

	if(m == nil)
		return memcmp(v, c, n) != 0;

	for(i = 0; i < n; i++)
		if((v[i] & m[i]) != c[i])
			return 1;
	return 0;
}

/*
 *  run an incoming packet through the filter tree.  if a
 *  conversation matched, the packet is queued to it with an
 *  interface address prepended; otherwise it is handed to the
 *  handler registered for its ip protocol, or freed if there
 *  is none.
 */
static void
ipmuxiput(Proto *p, Ipifc *ifc, Block *bp)
{
	Fs *f = p->f;
	Conv *c;
	Iplifc *lifc;
	Ipmux *mux;
	uchar *v;
	Ip4hdr *ip4;
	Ip6hdr *ip6;
	int off, hl;

	ip4 = (Ip4hdr*)bp->rp;
	if((ip4->vihl & 0xF0) == IP_VER4) {
		hl = (ip4->vihl&0x0F)<<2;
		ip6 = nil;
	} else {
		hl = IP6HDR;
		ip6 = (Ip6hdr*)ip4;
	}

	if(p->priv == nil)
		goto nomatch;

	c = nil;
	lifc = nil;

	/* run the filter */
	rlock(f);
	mux = f->ipmux->priv;
	while(mux != nil){
		switch(mux->type){
		case Tifc:
			/* compared against the interface's local addresses, not the packet */
			if(mux->len != IPaddrlen)
				goto no;
			for(lifc = ifc->lifc; lifc != nil; lifc = lifc->next)
				for(v = mux->val; v < mux->e; v += IPaddrlen)
					if(maskmemcmp(mux->mask, lifc->local, v, IPaddrlen) == 0)
						goto yes;
			goto no;
		case Tdata:
			/* data offsets start after the ip header */
			off = hl;
			break;
		default:
			off = 0;
			break;
		}
		off += mux->off;
		/* a field that isn't wholly inside the block can't match */
		if(off < 0 || off + mux->len > BLEN(bp))
			goto no;
		for(v = mux->val; v < mux->e; v += mux->len)
			if(maskmemcmp(mux->mask, bp->rp + off, v, mux->len) == 0)
				goto yes;
no:
		mux = mux->no;
		continue;
yes:
		/* remember the most recent conversation seen on the way down */
		if(mux->conv != nil)
			c = mux->conv;
		mux = mux->yes;
	}
	runlock(f);

	if(c != nil){
		/*
		 *  tack on interface address: the one that satisfied
		 *  the last ifc test evaluated if it matched, else the
		 *  first one on the interface, else IPnoaddr.
		 */
		bp = padblock(bp, IPaddrlen);
		if(lifc == nil)
			lifc = ifc->lifc;
		ipmove(bp->rp, lifc != nil ? lifc->local : IPnoaddr);
		qpass(c->rq, concatblock(bp));
		return;
	}

nomatch:
	/* doesn't match any filter, hand it to the specific protocol handler */
	if(ip6 != nil)
		p = f->t2p[ip6->proto];
	else
		p = f->t2p[ip4->proto];
	if(p != nil && p->rcv != nil){
		(*p->rcv)(p, ifc, bp);
		return;
	}
	freeblist(bp);
}

/*
 *  print the tree rooted at mux into buf, one node per line,
 *  indented by level spaces, with the no branch printed before
 *  the yes branch.  a nil branch prints as an empty line.
 *  every node is shown as name[first:last]&mask=val|val|...
 *  where name is "data" for Tdata nodes and "iph" for all
 *  other types.  returns the sum of the snprint results.
 */
static int
ipmuxsprint(Ipmux *mux, int level, char *buf, int len)
{
	int i, j, n;
	uchar *v;

	n = 0;
	for(i = 0; i < level; i++)
		n += snprint(buf+n, len-n, " ");
	if(mux == nil){
		n += snprint(buf+n, len-n, "\n");
		return n;
	}
	n += snprint(buf+n, len-n, "%s[%d:%d]",
		mux->type == Tdata ? "data": "iph",
		mux->off, mux->off+mux->len-1);
	if(mux->mask != nil){
		n += snprint(buf+n, len-n, "&");
		for(i = 0; i < mux->len; i++)
			n += snprint(buf+n, len - n, "%2.2ux", mux->mask[i]);
	}
	n += snprint(buf+n, len-n, "=");
	v = mux->val;
	for(j = 0; j < mux->n; j++){
		for(i = 0; i < mux->len; i++)
			n += snprint(buf+n, len - n, "%2.2ux", *v++);
		n += snprint(buf+n, len-n, "|");
	}
	n += snprint(buf+n, len-n, "\n");
	level++;
	n += ipmuxsprint(mux->no, level, buf+n, len-n);
	n += ipmuxsprint(mux->yes, level, buf+n, len-n);
	return n;
}

/*
 *  print the whole demultiplexor tree into buf while holding
 *  the read lock on f
 */
static int
ipmuxstats(Proto *p, char *buf, int len)
{
	int n;
	Fs *f = p->f;

	rlock(f);
	n = ipmuxsprint(p->priv, 0, buf, len);
	runlock(f);

	return n;
}

/*
 *  allocate the ipmux protocol, fill in its entry points
 *  and register it with f
 */
extern void
ipmuxinit(Fs *f)
{
	Proto *ipmux;

	ipmux = smalloc(sizeof(Proto));
	ipmux->priv = nil;
	ipmux->name = "ipmux";
	ipmux->connect = ipmuxconnect;
	ipmux->announce = ipmuxannounce;
	ipmux->state = ipmuxstate;
	ipmux->create = ipmuxcreate;
	ipmux->close = ipmuxclose;
	ipmux->rcv = ipmuxiput;
	ipmux->ctl = nil;
	ipmux->advise = nil;
	ipmux->stats = ipmuxstats;
	ipmux->ipproto = -1;
	ipmux->nc = 64;
	ipmux->ptclsize = sizeof(Ipmuxrock);

	f->ipmux = ipmux;			/* hack for Fsrcvpcol */

	Fsproto(f, ipmux);
}