/*
 * See Aho and Corasick, Comm. ACM 18:6, and Hoffmann and O'Donnell,
 * JACM 29:1, for details of the matching algorithm.
 */

#ifndef va_start
#include <stdarg.h>
#endif
/* from libc.h */
extern	void*	malloc(long);
extern	void	free(void*);
extern	void	abort(void);
/*
 * Internal consistency check.  When `a' is false, report
 * "twig internal error <str>" through fprint(2, ...) and call abort().
 * The do{}while(0) wrapper makes the expansion a single statement, so
 * the macro can be used safely as the body of an unbraced if/else.
 */
#define _twig_assert(a,str) \
	do{if(!(a)){\
		fprint(2, "twig internal error %s\n", (str));\
		abort();\
	}}while(0)

/* these are user defined, so can be macros */
#ifndef mtValue
int		mtValue(NODEPTR);
#endif
#ifndef mtSetNodes
void		mtSetNodes(NODEPTR, int, NODEPTR);
#endif
#ifndef mtGetNodes
NODEPTR		mtGetNodes(NODEPTR, int);
#endif

/* made by twig proper */
NODEPTR		mtAction(int, __match **, skeleton *);
short		mtEvalCost(__match *, __match **, skeleton *);

/* stuff defined here */
void		_addmatches(int, skeleton *, __match *);
void		_closure(int, int, skeleton *);
void		_do(skeleton *, __match *, int);
void		_evalleaves(__match **);
__match		**_getleaves(__match *, skeleton *);
int		_machstep(short, short);
void		_match(void);
void		_matchinit(void);
void		_merge(skeleton *, skeleton *);
NODEPTR		_mtG(NODEPTR, ...);
skeleton	*_walk(skeleton *, int);
__match		*_allocmatch(void);
skeleton	*_allocskel(void);
__partial	*_allocpartial(void);
void		_freematch(__match *);
void		_freeskel(skeleton *);
void		_freepartial(__partial *);
void		_prskel(skeleton *, int);

/*
 * Debug switches.  Either flag being nonzero makes _walk dump the
 * skeleton at each rewrite and at the top-level match; mtDebug alone
 * additionally makes _walk dump every skeleton before returning.
 */
int mtDebug = 0;
int treedebug = 0;
/*
 * Tables defined elsewhere (presumably generated by twig -- confirm):
 *	mtStart		index in mtTable of the start state
 *	mtTable		state machine entries, indexed by state
 *	mtAccept	accept lists, indexed via an ACCEPT entry in mtTable
 *	mtMap		label for each tree number
 *	mtPaths		leaf path programs interpreted by _getleaves
 *	mtPathStart	index in mtPaths of each tree's path program
 */
extern short mtStart;
extern short mtTable[], mtAccept[], mtMap[], mtPaths[], mtPathStart[];
#ifdef LINE_XREF
	/* per-tree value printed by _prskel next to the tree number */
	extern short mtLine[];
#endif

/*
 * Pool of __match pointers from which _getleaves carves the
 * zero-terminated leaf vectors.  _mpbtop is the next unused slot;
 * _match resets it to the start of _mpblock before each walk.
 * MPBLOCKSIZ may be predefined to change the pool size.
 */
#ifndef MPBLOCKSIZ
#define MPBLOCKSIZ 10000
#endif
__match *_mpblock[MPBLOCKSIZ], **_mpbtop;

/*
 * See sym.h in the preprocessor for details.
 * Basically used to support the $%n$ construct.
 *
 * Build the leaf vector for match mp rooted at skeleton skp by
 * interpreting the path program mtPaths[mtPathStart[mp->tree]].
 * The first short of the program gives the number of slots to
 * reserve (one more is added for the terminator); the remaining
 * shorts are opcodes:
 *	ePUSH	remember the current skeleton and descend to leftchild
 *	eNEXT	move to the sibling
 *	eEVAL	append succ[label] of the current skeleton; the label
 *		follows the opcode, tagged
 *	ePOP	return to the most recently remembered skeleton
 *	eSTOP	terminate the vector with 0 and return it
 *
 * Returns a zero-terminated vector carved out of _mpblock.  Aborts
 * via _twig_assert if the pool is exhausted, if the path nests deeper
 * than MAXDEPTH, if it pops more than it pushed, or if an eEVAL finds
 * no match for the requested label.
 */
__match **
_getleaves(__match *mp, skeleton *skp)
{
	skeleton *stack[MAXDEPTH];
	skeleton **stp = stack;
	short *sip = &mtPaths[mtPathStart[mp->tree]];
	__match **mmp = _mpbtop;
	__match **mmp2 = mmp;
	int need;

	/*
	 * Check the room left before advancing, so _mpbtop is never
	 * moved beyond the end of _mpblock.
	 */
	need = *sip++ + 1;
	_twig_assert(need <= &_mpblock[MPBLOCKSIZ] - _mpbtop, "match block overflow");
	_mpbtop += need;

	for(;;)
		switch(*sip++){
		case ePUSH:
			_twig_assert(stp < &stack[MAXDEPTH], "leaf path too deep");
			*stp++ = skp;
			skp = skp->leftchild;
			break;
		case eNEXT:
			skp = skp->sibling;
			break;
		case eEVAL:
			mp = skp->succ[M_DETAG(*sip++)];
			_twig_assert(mp != 0, "bad eEVAL");
			*mmp++ = mp;
			break;
		case ePOP:
			_twig_assert(stp > stack, "unbalanced ePOP");
			skp = *--stp;
			break;
		case eSTOP:
			*mmp = 0;
			return mmp2;
		}
}

/*
 * Execute the action of a winning match.
 *
 *	sp	skeleton passed through to mtAction (and dumped when
 *		there is no winner)
 *	winner	the match to execute; 0 means nothing matched
 *	evalall	nonzero when called from _evalleaves
 *
 * With no winner the skeleton is dumped together with a "no winner"
 * diagnostic and nothing else happens.  Otherwise REORDER is applied
 * to the leaves when the match is xDEFER, or when evalall is set and
 * the match is not xTOPDOWN; then mtAction is run and its result is
 * stored both in the match's skeleton and, through mtSetNodes, in the
 * parent node.
 */
void
_do(skeleton *sp, __match *winner, int evalall)
{
	skeleton *skel;

	if(winner == 0) {
		_prskel(sp, 0);
		/* newline-terminated like every other diagnostic here */
		fprint(2, "no winner\n");
		return;
	}

	skel = winner->skel;
	if(winner->mode == xDEFER || (evalall && winner->mode != xTOPDOWN))
		REORDER(winner->lleaves);
	skel->root = mtAction(winner->tree, winner->lleaves, sp);
	mtSetNodes(skel->parent, skel->nson, skel->root);
}

/*
 * Run _do, with evalall set, on every match in the zero-terminated
 * vector mpp, in order.
 */
void
_evalleaves(__match **mpp)
{
	__match *leaf;

	for(; (leaf = *mpp) != 0; mpp++)
		_do(leaf->skel, leaf, 1);
}

/*
 * Match the subject tree rooted at sp->root, building the skeleton
 * tree underneath sp as it goes.
 *
 *	sp	skeleton for this node; root (and for non-top nodes
 *		parent and nson) must already be filled in
 *	ostate	machine state in effect before this node's value is read
 *
 * Returns sp.
 */
skeleton *
_walk(skeleton *sp, int ostate)
{
	int state, nstate, nson;
	__partial *pp;
	__match *mp;
	skeleton *nsp, *lastchild = 0;
	NODEPTR son, root;

	root = sp->root;
	nson = 1;
	sp->mincost = INFINITY;
	/* advance the machine over this node's value */
	state = _machstep(ostate, mtValue(root));

	/* children are numbered from 1; a false mtGetNodes result ends the list */
	while(son = mtGetNodes(root, nson)){
		nstate = _machstep(state, MV_BRANCH(nson));
		nsp = _allocskel();
		nsp->root = son;
		nsp->parent = root;
		nsp->nson = nson;
		_walk(nsp, nstate);
		/*
		 * mincost only drops below INFINITY when _addmatches
		 * recorded an xREWRITE winner.  Run it now and discard the
		 * child skeleton.  nson is deliberately not advanced, so
		 * the same child slot is fetched and walked again
		 * (presumably to match the rewritten subtree, which _do
		 * installs with mtSetNodes -- confirm).
		 */
		if(COSTLESS(nsp->mincost, INFINITY)){
			_twig_assert(nsp->winner->mode == xREWRITE, "bad mode");
			if(mtDebug || treedebug) {
				fprint(2, "rewrite\n");
				_prskel(nsp, 0);
			}
			_do(nsp, nsp->winner, 0);
			_freeskel(nsp);
			continue;
		}
		/* fold the child's partial matches into ours and link it in */
		_merge(sp, nsp);
		if(lastchild == 0)
			sp->leftchild = nsp;
		else
			lastchild->sibling = nsp;
		lastchild = nsp;
		nson++;
	}

	/* a partial whose low bit is set is a complete match at this node */
	for(pp = sp->partial; pp < &sp->partial[sp->treecnt]; pp++)
		if(pp->bits & 01){
			mp = _allocmatch();
			mp->tree = pp->treeno;
			_addmatches(ostate, sp, mp);
		}
	/*
	 * son is always 0 once the loop has ended, so this tests only
	 * nson == 1, i.e. no child was linked in: a leaf.  Its accept
	 * list has to be processed here.
	 */
	if(son == 0 && nson == 1)
		_closure(state, ostate, sp);

	sp->rightchild = lastchild;
	/*
	 * root == 0 marks the top skeleton created by _match: choose the
	 * cheapest match recorded on its first child and execute it.
	 * NOTE(review): nsp is dereferenced without a null check; this
	 * relies on the top skeleton always getting a child -- confirm.
	 */
	if(root == 0){
		COST c;
		__match *win;
		int i;

		nsp = sp->leftchild;
		c = INFINITY;
		win = 0;
		for(i = 0; i < MAXLABELS; i++){
			mp = nsp->succ[i];
			if(mp != 0 && COSTLESS(mp->cost, c)){
				c = mp->cost;
				win = mp;
			}
		}
		if(mtDebug || treedebug)
			_prskel(nsp, 0);
		/* win may still be 0; _do reports "no winner" in that case */
		_do(nsp, win, 0);
	}
	if(mtDebug)
		_prskel(sp, 0);
	return sp;
}

static short _nodetab[MAXNDVAL], _labeltab[MAXLABELS];

/*
 * Expand the start state, whose transition list is long, into two
 * directly indexed tables: _nodetab for node inputs and _labeltab for
 * label inputs.  Inputs without a transition map to HANG.
 * Must be called before the matcher is used.
 */
void
_matchinit(void)
{
	short *ent;
	int i;

	for(i = 0; i < MAXNDVAL; i++)
		_nodetab[i] = HANG;
	for(i = 0; i < MAXLABELS; i++)
		_labeltab[i] = HANG;

	ent = &mtTable[mtStart];
	_twig_assert(*ent == TABLE, "mtTable[mtStart]!=TABLE");

	/* (input, next state) pairs follow the TABLE tag, ended by -1 */
	ent++;
	while(*ent != -1){
		if(MI_NODE(*ent))
			_nodetab[M_DETAG(*ent)] = ent[1];
		else if(MI_LABEL(*ent))
			_labeltab[M_DETAG(*ent)] = ent[1];
		ent += 2;
	}
}

/*
 * Advance the matching machine by one input symbol.
 *
 *	state	current state (an index into mtTable) or HANG
 *	input	tagged input symbol: node value, branch number or label
 *
 * Returns the next state, or HANG when no transition exists.
 *
 * From HANG only the branch-1 input leads anywhere, namely to the
 * start state.  A state in mtTable is read as an optional ACCEPT pair,
 * an optional TABLE of (input, next state) pairs ended by -1, and then
 * either FAIL followed by the state to fall back to, or nothing, in
 * which case the start state is tried once before giving up.  Node and
 * label inputs in the start state are answered from the tables built
 * by _matchinit rather than by scanning.
 */
int
_machstep(short state, short input)
{
	short *stp;
	int start = 0;

	if(state == HANG)
		return input == MV_BRANCH(1) ? mtStart : HANG;

	/* only index mtTable once state is known not to be HANG */
	stp = &mtTable[state];

	for(;;){
		if(stp == &mtTable[mtStart]){
			if(MI_NODE(input))
				return _nodetab[M_DETAG(input)];
			if(MI_LABEL(input))
				return _labeltab[M_DETAG(input)];
		}

		/* skip the ACCEPT tag and its operand */
		if(*stp == ACCEPT)
			stp += 2;
		if(*stp == TABLE){
			stp++;
			while(*stp != -1){
				if(input == *stp)
					return stp[1];
				stp += 2;
			}
			stp++;
		}

		if(*stp == FAIL){
			/* follow the failure link and look again */
			stp = &mtTable[stp[1]];
		}else if(start){
			/* the start state has already been retried */
			return HANG;
		}else{
			stp = &mtTable[mtStart];
			start = 1;
		}
	}
}

/*
 * Offer match np (with np->tree already set) as a candidate for
 * skeleton sp.  ostate is the machine state in effect before sp's
 * node was read.
 *
 * np is freed and dropped when its label cannot be consumed from
 * ostate, when mtEvalCost answers xABORT, or when the match already
 * recorded for the same label is at least as cheap.  Otherwise np
 * replaces that match (which is freed), sp's rewrite winner is
 * updated, and the closure of the state reached on the label is
 * processed.
 */
void
_addmatches(int ostate, skeleton *sp, __match *np)
{
	int lab = mtMap[np->tree];
	int nstate;
	__match *prev;

	/*
	 * this is a very poor substitute for good design of the DFA.
	 * What we need is a special case that allows any label to be accepted
	 * by the start state but we don't want the start state to recognize
	 * them after a failure.
	 */
	nstate = _machstep(ostate, MV_LABEL(lab));
	if(nstate == HANG && ostate != mtStart){
		_freematch(np);
		return;
	}

	np->lleaves = _getleaves(np, sp);
	np->skel = sp;
	np->mode = mtEvalCost(np, np->lleaves, sp);
	if(np->mode == xABORT){
		_freematch(np);
		return;
	}

	/* keep only the cheapest match per label */
	prev = sp->succ[lab];
	if(prev != 0){
		if(!COSTLESS(np->cost, prev->cost)){
			_freematch(np);
			return;
		}
		_freematch(prev);
	}

	/*
	 * A new cheapest match becomes the winner only if it is a
	 * rewrite; any other mode clears the current winner.
	 */
	if(COSTLESS(np->cost, sp->mincost)){
		if(np->mode != xREWRITE){
			sp->mincost = INFINITY;
			sp->winner = 0;
		}else{
			sp->mincost = np->cost;
			sp->winner = np;
		}
	}

	sp->succ[lab] = np;
	_closure(nstate, ostate, sp);
}

/*
 * Process the accept list of `state' for skeleton skp.
 *
 *	state	state just reached; ignored if HANG or not accepting
 *	ostate	state before skp's node was read, handed to _addmatches
 *	skp	skeleton collecting the matches
 *
 * The accept list is a run of (tree, n) pairs in mtAccept ended by -1.
 * A pair with n == 0 is a complete match and is offered to
 * _addmatches.  Any other pair sets bit n in the partial entry for
 * that tree, creating the entry if skp has none yet.  (_merge halves
 * the bits for each step up the tree and _walk treats bit 0 as a
 * match, so n is presumably the distance to the pattern's root --
 * confirm.)
 */
void
_closure(int state, int ostate, skeleton *skp)
{
	short *sp;
	__match *mp;

	if(state == HANG)
		return;
	/* only index mtTable once state is known not to be HANG */
	sp = &mtTable[state];
	if(*sp != ACCEPT)
		return;

	for(sp = &mtAccept[sp[1]]; *sp != -1; sp += 2){
		if(sp[1] == 0){
			mp = _allocmatch();
			mp->tree = *sp;
			_addmatches(ostate, skp, mp);
		}else{
			__partial *pp;
			/* recomputed each time: _addmatches may add entries */
			__partial *lim = &skp->partial[skp->treecnt];

			for(pp = skp->partial; pp < lim; pp++)
				if(pp->treeno == *sp)
					break;
			if(pp == lim){
				/* first sighting of this tree: append an entry */
				skp->treecnt++;
				pp->treeno = *sp;
				pp->bits = 1 << sp[1];
			}else
				pp->bits |= 1 << sp[1];
		}
	}
}

/*
 * Fold the partial matches of child `new' into its parent `old'.
 * Each bit set is divided by two on the way up.  For the first child
 * (nson == 1) the parent simply takes the child's shifted partials;
 * for later children the parent keeps only the trees present on both
 * sides, and-ing the parent's bits with the child's shifted bits.
 */
void
_merge(skeleton *old, skeleton *new)
{
	__partial *src = new->partial;

	if(new->nson == 1){
		__partial *dst = old->partial;
		__partial *end = src + new->treecnt;

		old->treecnt = new->treecnt;
		while(src < end){
			dst->treeno = src->treeno;
			dst->bits = src->bits / 2;
			dst++;
			src++;
		}
		return;
	}

	{
		__partial *fresh = _allocpartial();
		__partial *out = fresh;
		__partial *oldend = old->partial + old->treecnt;
		__partial *cur;
		int left;

		for(left = new->treecnt; left != 0; left--, src++){
			/* look for the same tree among the parent's partials */
			for(cur = old->partial; cur < oldend; cur++){
				if(cur->treeno != src->treeno)
					continue;
				out->treeno = cur->treeno;
				out->bits = cur->bits & (src->bits / 2);
				out++;
				break;
			}
		}
		_freepartial(old->partial);
		old->partial = fresh;
		old->treecnt = out - fresh;
	}
}
 
/* memory management */

/* number of objects each allocator obtains from malloc at a time */
#define BLKF	100

/*
 * Storage cell for a __match.  While the cell is on the free list the
 * same memory holds the link to the next free cell.
 */
typedef union __matchalloc{
	__match			it;
	union __matchalloc	*next;
}__matchalloc;
/* head of the list of freed __match cells */
static __matchalloc *__matchfree = 0;
#ifdef CHECKMEM
/* matches handed out by _allocmatch and not yet given to _freematch */
static int a_matches;
#endif

/*
 * Return storage for one __match.  Cells on the free list are reused
 * first; otherwise cells are taken one by one from a block of BLKF
 * obtained with malloc.  The contents are not initialized.
 * Aborts if malloc fails.
 */
__match *
_allocmatch(void)
{
	static int count = 0;		/* cells left in the current block */
	static __matchalloc *block = 0;	/* next unused cell in the block */
	__matchalloc *m;

	if(__matchfree){
		m = __matchfree;
		__matchfree = __matchfree->next;
	}else{
		if(count == 0){
			block = (void *)malloc(BLKF * sizeof *block);
			/* fail loudly here instead of through a null pointer later */
			if(block == 0)
				abort();
			count = BLKF;
		}
		m = block++;
		count--;
	}
#ifdef CHECKMEM
	a_matches++;
#endif
	return (__match *)m;
}

/*
 * Give a __match obtained from _allocmatch back by pushing its cell
 * onto the free list.
 */
void
_freematch(__match *mp)
{
	__matchalloc *cell;

	cell = (__matchalloc *)mp;
	cell->next = __matchfree;
	__matchfree = cell;
#ifdef CHECKMEM
	a_matches--;
#endif
}

/*
 * Storage cell for one array of MAXTREES partial-match entries.  While
 * the cell is on the free list the same memory holds the link to the
 * next free cell.
 */
typedef union __partalloc{
	__partial		it[MAXTREES];
	union __partalloc	*next;
}__partalloc;
/* head of the list of freed partial arrays */
static __partalloc *__partfree = 0;
#ifdef CHECKMEM
/* arrays handed out by _allocpartial and not yet given to _freepartial */
static int a_partials;
#endif

/*
 * Return storage for an array of MAXTREES __partial entries.  Arrays
 * on the free list are reused first; otherwise they are taken one by
 * one from a block of BLKF obtained with malloc.  The contents are
 * not initialized.  Aborts if malloc fails.
 */
__partial *
_allocpartial(void)
{
	static int count = 0;		/* arrays left in the current block */
	static __partalloc *block = 0;	/* next unused array in the block */
	__partalloc *p;

	if(__partfree != 0){
		p = __partfree;
		__partfree = __partfree->next;
	}else{
		if(count == 0){
			block = (void *)malloc(BLKF * sizeof *block);
			/* fail loudly here instead of through a null pointer later */
			if(block == 0)
				abort();
			count = BLKF;
		}
		p = block++;
		count--;
	}
#ifdef CHECKMEM
	a_partials++;
#endif
	return (__partial *)p;
}

/*
 * Give a partial array obtained from _allocpartial back by pushing
 * its cell onto the free list.
 */
void
_freepartial(__partial *pp)
{
	__partalloc *cell;

	cell = (__partalloc *)pp;
	cell->next = __partfree;
	__partfree = cell;
#ifdef CHECKMEM
	a_partials--;
#endif
}

/*
 * Storage cell for a skeleton.  While the cell is on the free list the
 * same memory holds the link to the next free cell.
 */
typedef union __skelalloc{
	skeleton		it;
	union __skelalloc	*next;
}__skelalloc;
/* head of the list of freed skeleton cells */
static __skelalloc *__skelfree = 0;

/*
 * Return a skeleton with no sibling, no leftchild, an empty succ[]
 * table, treecnt 0 and a fresh partial array.  All other fields are
 * left for the caller to fill in.  Cells on the free list are reused
 * first; otherwise cells are taken one by one from a block of BLKF
 * obtained with malloc.  Aborts if malloc fails.
 */
skeleton *
_allocskel(void)
{
	static int count = 0;		/* cells left in the current block */
	static __skelalloc *block = 0;	/* next unused cell in the block */
	__skelalloc *sf;
	skeleton *s;
	int i;

	if(__skelfree){
		sf = __skelfree;
		__skelfree = sf->next;
	}else{
		if(count == 0){
			block = (void *)malloc(BLKF * sizeof *block);
			/* fail loudly here instead of through a null pointer later */
			if(block == 0)
				abort();
			count = BLKF;
		}
		sf = block++;
		count--;
	}
	s = (skeleton *)sf;
	s->sibling = 0;
	s->leftchild = 0;
	for(i = 0; i < MAXLABELS; i++)
		s->succ[i] = 0;
	s->treecnt = 0;
	s->partial = _allocpartial();
	return s;
}

/*
 * Release skeleton s together with everything hanging off it: the
 * skeletons reachable through leftchild and sibling, its partial
 * array and every match recorded in succ[].  A null s is ignored.
 */
void
_freeskel(skeleton *s)
{
	__skelalloc *cell;
	__match *m;
	int lab;

	if(s == 0)
		return;

	/* the recursive calls ignore null pointers themselves */
	_freeskel(s->leftchild);
	_freeskel(s->sibling);
	_freepartial(s->partial);

	for(lab = 0; lab < MAXLABELS; lab++){
		m = s->succ[lab];
		if(m != 0)
			_freematch(m);
	}

	cell = (__skelalloc *)s;
	cell->next = __skelfree;
	__skelfree = cell;
}

/*
 * Run one complete match: create a top skeleton with a null root,
 * reset the leaf-vector pool, walk starting from the HANG state and
 * release every skeleton afterwards.  Under CHECKMEM, verify that all
 * matches and partial arrays were given back.
 */
void
_match(void)
{
	skeleton *top = _allocskel();

	top->root = 0;
	_mpbtop = _mpblock;
	top = _walk(top, HANG);
	_freeskel(top);
#ifdef CHECKMEM
	_twig_assert(a_matches == 0, "__match memory leak");
	_twig_assert(a_partials == 0, "__partial memory leak");
#endif
}

/*
 * Follow a path of child numbers down from root.  The variable
 * arguments are ints, each handed to mtGetNodes in turn, and the list
 * must end with -1.  Returns the node reached.
 */
NODEPTR
_mtG(NODEPTR root, ...)
{
	va_list ap;
	int son;

	va_start(ap, root);
	for(;;){
		son = va_arg(ap, int);
		if(son == -1)
			break;
		root = mtGetNodes(root, son);
	}
	va_end(ap);
	return root;
}

/* diagnostic routines */

/*
 * Dump skeleton skp, then its children, then its siblings.  Each
 * skeleton is printed as three lines indented by lvl steps: a "###"
 * marker, the matches recorded in succ[] as [tree cost], and the
 * partial matches as (tree bits).  Children are indented two steps
 * further than their parent.  With LINE_XREF defined the mtLine entry
 * of each tree is printed in angle brackets after the tree number.
 */
void
_prskel(skeleton *skp, int lvl)
{
	int n;
	__match *m;
	__partial *part;

	/* siblings are handled by looping rather than by recursion */
	for(; skp != 0; skp = skp->sibling){
		for(n = 0; n < lvl; n++)
			fprint(2, "  ");
		fprint(2, "###\n");

		for(n = 0; n < lvl; n++)
			fprint(2, "  ");
		for(n = 0; n < MAXLABELS; n++){
			m = skp->succ[n];
			if(m == 0)
				continue;
#ifdef LINE_XREF
			fprint(2, "[%d<%d> %d]", m->tree,
				mtLine[m->tree], m->cost);
#else
			fprint(2, "[%d %d]", m->tree, m->cost);
#endif
		}
		fprint(2, "\n");

		for(n = 0; n < lvl; n++)
			fprint(2, "  ");
		part = skp->partial;
		for(n = 0; n < skp->treecnt; n++, part++){
#ifdef LINE_XREF
			fprint(2, "(%d<%d> %x)", part->treeno,
				mtLine[part->treeno], part->bits);
#else
			fprint(2, "(%d %x)", part->treeno, part->bits);
#endif
		}
		fprint(2, "\n");

		_prskel(skp->leftchild, lvl + 2);
	}
}
