// Malloc verification

// This checker only supports MIPS: abort loading with an error when
// objtype names any other architecture.
if objtype != "mips" then
	error("malloc checker for MIPS\n");

// Start a new process and run it under the malloc checker.
//
// Breakpoints are planted a little way into malloc and free and at
// _exits+4.  The tracking state (blocklist, framelist, stacks, nstack)
// is reset, then the process is continued repeatedly: a stop at the
// malloc breakpoint is handled by gotmalloc(), a stop at the free
// breakpoint by gotfree(), and a stop anywhere else reports the process
// state with pstop() and ends the run, returning {}.
defn go()
{
	new();

	// Break slightly past each entry point rather than at it.
	xmalloc = malloc + 4;		// Delay slot filler may pull in first
	xfree = free + 0xc;		// instruction

	bpset(xmalloc);
	bpset(xfree);
	bpset(_exits+4);

	// blocklist and framelist are parallel lists: entry i of framelist is
	// the index into stacks of the trace that allocated blocklist[i].
	blocklist = {};			// List of allocated blocks
	framelist = {};			// Frames they were allocated from
	stacks = {};			// The stacks
	nstack = 0;

	while 1 do {
		cont();
		pc = *PC;
		if pc == xmalloc then
			gotmalloc();
		else
		if pc == xfree then
			gotfree();
		else {
			// Stopped somewhere other than malloc or free: report and finish.
			pstop(pid);
			return {};
		}
	}
}

// Compare two stack traces frame by frame.
// Each trace is a list of frames; two frames agree when their first two
// elements are equal.  Returns 1 when both traces have the same length and
// every pair of frames agrees, 0 otherwise.
defn stkcmp(s1, s2)
{
	local fa;
	local fb;

	while s1 do {
		fb = head s2;
		if fb == {} then		// second trace ran out first
			return 0;

		fa = head s1;
		if fa[0] != fb[0] then
			return 0;
		if fa[1] != fb[1] then
			return 0;

		s2 = tail s2;
		s1 = tail s1;
	}

	// First trace is exhausted; anything left in the second is a mismatch.
	if s2 != {} then
		return 0;

	return 1;
}

// Return the index of stack trace s in the global list stacks.
// If an equal trace (according to stkcmp) is already recorded its position
// is returned; otherwise s is appended, nstack is bumped, and the index of
// the new entry is returned.
defn stkindex(s)
{
	local rest;
	local idx;

	idx = 0;
	rest = stacks;
	while rest do {
		if stkcmp(head rest, s) then
			return idx;

		idx = idx+1;
		rest = tail rest;
	}

	// Not seen before: record it at the end of the list.
	idx = nstack;
	stacks = append stacks, s;
	nstack = idx+1;
	return idx;
}

// Handle a stop at the malloc breakpoint.
// Captures the stack trace at the call, runs the process until the
// address held in R31 is reached, then records the value of R1 in
// blocklist and the index of the captured trace in framelist.
defn gotmalloc()
{
	local raddr;
	local stk;

	raddr = *R31;			// address execution is expected to return to
	bpset(raddr);
	stk = strace(*PC, *SP, *R31);	// Save a copy of the stack up to the malloc
	cont();				// Continue until malloc returns
	bpdel(raddr);

	// NOTE(review): no check is made that the stop after cont() is at raddr.
	blocklist = append blocklist, *R1;	// presumably malloc's result — confirm
	framelist = append framelist, stkindex(stk);
}

// Total the sizes of the recorded blocks that were allocated from stack
// trace number s (an index into stacks, as stored in framelist).
// Walks blocklist and framelist in parallel; returns the summed size, or 0
// when no recorded block belongs to that trace.
defn pentry(s)
{
	local sno;
	local blst;
	local flst;		// was an undeclared global; keep the cursor private
	local addr;		// likewise
	local gotone;
	local size;

	gotone = 0;
	blst = blocklist;
	flst = framelist;

	while blst do {
		sno = head flst;
		if sno == s then {
			addr = fmt(head blst, 'X');
			// The word 12 bytes before the block is used as a shift
			// count for the size (presumably the allocator's size-class
			// field — confirm against the allocator).
			size = 1<<*(addr-12);
//			print(addr, "of ", fmt(size, 'D'), "bytes\n");
			gotone = gotone + size;
		}
		blst = tail blst;
		flst = tail flst;
	}
	return gotone;
}

// Report every recorded stack trace that still owns blocks.
// For each trace index below nstack, pentry() gives the total size still
// attributed to it; non-zero totals are printed followed by the trace.
defn leak()
{
	local idx;
	local lost;

	idx = 0;
	while idx < nstack do {
		lost = pentry(idx);
		if lost then {
			print("Lost a total of ", fmt(lost, 'D'), "bytes from:\n");
			mstack(stacks[idx]);
		}
		idx = idx+1;
	}
}

// Scan process memory from start up to (not including) end and drop every
// recorded block whose address appears as a value in that range.
// A matching entry is removed from both blocklist and framelist so the two
// lists stay in step.
defn scanmem(start, end)
{
	local hit;

	start = fmt(start, 'X');	// step through memory in 'X'-sized units
	while start < end do {
		hit = match(*start, blocklist);
		if hit >= 0 then {
			framelist = delete framelist, hit;
			blocklist = delete blocklist, hit;
		}
		start++;
	}
}

// Delete blocks with references, then report what is left.
// Three ranges are scanned with scanmem(): etext..edata, edata..*bloc and
// *SP..0x80000000.  Any recorded block whose address is found there is
// dropped; leak() then prints the remaining blocks grouped by the stack
// trace that allocated them.
defn refs()			// Delete blocks with references
{
	print("data...");
	scanmem(etext, edata);

	print("bss...");
	scanmem(edata, *bloc);

	print("stack...");
	scanmem(*SP, 0x80000000);

	print("\n");
	leak();	
}

// Print a saved stack trace, one frame per line.
// stk is a list of frames; for each frame, element 0 is printed as a
// symbolic address with its source file and line, and element 1 as the
// address it was called from, followed by that address's file and line
// via pfl().
defn mstack(stk)
{
	local frame;		// was an undeclared global; keep it private

	while stk do {
		frame = head stk;
		print("\t", fmt(frame[0], 'a'), "() ");
		print(pcfile(frame[0]), ":", pcline(frame[0]));
		print("called from ", fmt(frame[1], 'a'), " ");
		pfl(frame[1]);
		stk = tail stk;
	}
}

// Handle a stop at the free breakpoint.
// Reads the pointer from R2 (presumably free's argument at this point in
// the function — confirm).  A zero pointer is ignored.  A pointer that is
// in blocklist is removed from blocklist and framelist; any other pointer
// is reported together with the current stack.
defn gotfree()
{
	local idx;
	local ptr;

	ptr = *R2;
	if ptr == 0 then
		return {};

	idx = match(ptr, blocklist);
	if idx >= 0 then {
		// Known block: forget it and the trace index recorded with it.
		framelist = delete framelist, idx;
		blocklist = delete blocklist, idx;
	}
	else {
		print("free(", ptr, ") called with bad pointer:\n");
		stk();
		return {};
	}
}

// Definition of stopped(pid) that does nothing and returns {}.
// NOTE(review): presumably this replaces acid's default stop handler so
// that the many breakpoint stops during go()/funcchk() are not each
// reported — confirm.
defn stopped(pid)
{
	return {};
}

// Start a new process and check the allocator arena around every call of
// the function at addr.
//
// A breakpoint is set at addr+4.  Each time it is hit, arenachk("entry")
// is run, a breakpoint is set at the address held in R31, the process is
// continued, arenachk("exit") is run and that breakpoint is removed.
// A stop at any other address reports the process state with pstop() and
// ends the run, returning {}.
defn funcchk(addr)
{
	local raddr;

	new();

	// NOTE(review): presumably offset by one instruction for the same
	// delay-slot reason given in go() — confirm.
	addr = addr+4;

	bpset(addr);

	while 1 do {
		cont();
		pc = *PC;
		if pc == addr then {
			arenachk("entry");
			raddr = *R31;	// address execution is expected to return to
			bpset(raddr);
			cont();
			arenachk("exit");
			bpdel(raddr);
		}
		else {
			pstop(pid);
			return {};
		}
	}

}

// Check the allocator's lists for corruption.
// when is a string naming the check point; it is appended to the error
// message.
//
// Starting at arena+4, successive words 4 bytes apart are treated as list
// heads, numbered from 0.  Every entry on list l must have l in its first
// word and 0 in the word at offset 4; the word at offset 8 links to the
// next entry.  The first violation raises an error.
defn arenachk(when)
{
	local addr;
	local l;
	local f;

	l = 0;
	addr = arena+4;
	// NOTE(review): loop 0, 32 appears to run 33 times (inclusive bounds) —
	// confirm this matches the number of list heads in the arena.
	loop 0, 32 do
	{
		if *addr then
		{
			f = *addr;
			while f do
			{
				if *f != l then
					error("BAD SIZE: arena corrupted at "+when);
				if *(f+4) != 0 then
					error("BAD MAGIC: arena corrupted at "+when);
				f = *(f+8);
			}
		}
		addr = addr+4;
		l = l+1;
	}
}

// Print a summary of the allocator's lists.
// Starting at arena+4, successive words 4 bytes apart are treated as list
// heads, numbered from 0.  For each non-empty list l the entries are
// counted by following the word at offset 8, each counting as 1<<l bytes,
// and one line is printed giving that size, the list head's address and
// the total.
defn arena()
{
	local addr;
	local nb;
	local l;
	local f;
	local bsize;		// was an undeclared global; keep it private

	l = 0;
	addr = arena+4;
	// NOTE(review): loop 0, 32 appears to run 33 times (inclusive bounds) —
	// confirm this matches the number of list heads in the arena.
	loop 0, 32 do
	{
		nb = fmt(0,'D');
		if *addr then
		{
			bsize = fmt(1<<l,'D');
			f = *addr;
			while f do
			{
				nb = nb+bsize;
				f = *(f+8);
			}
			print(bsize, addr, "has ", nb, "bytes free\n");
		}
		addr = addr+4;
		l = l+1;
	}
}

// Print this library's path when the file is loaded (no trailing newline).
print("/sys/lib/acid/alefmalloc");
