#include <stdio.h>
#include <string.h>
#include <sys/types.h>
#include <jvmti.h>
#include <tijmp.h>
#include <hash.h>
#include <walk_heap.h>
#include <tag_list.h>
#include <gc.h>
#include <tijmp_class_handler.h>

/* JVMTI environment; defined in another translation unit. */
extern jvmtiEnv *jvmti;

/* Next tag given to an untagged object by tag_untagged_objects, which
   post-decrements it; it starts at -1 so those tags count down from -1. */
jlong current_object_tag = -1;

/* One back link recorded by link_back: which tagged object refers to the
   owned object, through which kind of reference, and at which field or
   array index. Nodes form a singly-linked list hung off a link_header. */
typedef struct link_info {
    jlong tag_owner;                        /* tag of the referring object */
    jvmtiHeapReferenceKind reference_kind;  /* kind of the reference */
    jint index;                             /* field index or array index */
    struct link_info *next;                 /* next owner, or NULL */
} link_info;

/* Hash table entry for one referenced ("owned") object: its tag and the
   list of objects found referring to it. */
typedef struct link_header link_header;
struct link_header {
    jlong tag_owned;      /* tag of the referenced object */
    link_info *owners;    /* head of the owner list, or NULL */
};

/*
 * Allocates, with JVMTI Allocate, a link_header for the object tagged *tag.
 * The header starts with an empty owner list.
 *
 * Returns the new header, or NULL if the allocation fails. Without this
 * check an uninitialized pointer would be dereferenced.
 * Headers are released with JVMTI Deallocate (see cleanup_htab_entries).
 */
static link_header* new_link_header (jlong* tag) {
    link_header* lh = NULL;
    jvmtiError err;

    err = (*jvmti)->Allocate (jvmti, sizeof(*lh), (unsigned char**)&lh);
    if (err != JVMTI_ERROR_NONE || lh == NULL)
        return NULL;
    lh->tag_owned = *tag;
    lh->owners = NULL;
    return lh;
}

/* State passed to link_back through FollowReferences' user_data. */
typedef struct {
    hashtab *htab;             /* back links, keyed by the referenced object's tag */
    tag_list *tijmp_classes;   /* classes checked with is_tijmp_class */
} back_link_control;

/*
 * FollowReferences heap reference callback that records back links.
 *
 * Handles instance-field, static-field and array-element references where
 * both the referenced object (*tag) and the referrer (*referrer_tag_ptr)
 * are tagged. For each one it prepends a link_info holding the referrer
 * tag, the reference kind and the field/array index. The link_info goes to
 * the link_header of the referenced object in
 * ((back_link_control*)user_data)->htab; the header is created on first use.
 *
 * A reference is not recorded when is_tijmp_class matches the referrer's
 * class tag or the referenced object's class tag.
 *
 * Returns JVMTI_VISIT_OBJECTS so traversal continues. The exception is a
 * referrer whose class is a tijmp class: that returns 0, so the referenced
 * object is not visited through that reference.
 */
static jint JNICALL link_back (jvmtiHeapReferenceKind reference_kind,
                               const jvmtiHeapReferenceInfo* reference_info,
                               jlong class_tag, jlong referrer_class_tag,
                               jlong size, jlong* tag, jlong* referrer_tag_ptr,
                               jint length, void* user_data) {
    back_link_control* blc = (back_link_control*)user_data;
    link_header* lh;
    link_info* li = NULL;
    jvmtiError err;

    /* Only these kinds have a referrer object and a field/array index. */
    if (reference_kind != JVMTI_HEAP_REFERENCE_FIELD &&
        reference_kind != JVMTI_HEAP_REFERENCE_ARRAY_ELEMENT &&
        reference_kind != JVMTI_HEAP_REFERENCE_STATIC_FIELD)
        return JVMTI_VISIT_OBJECTS;
    /* Links are reported by tag, so both ends must be tagged. */
    if (*tag == 0 || *referrer_tag_ptr == 0)
        return JVMTI_VISIT_OBJECTS;

    /* ignore references to and from tijmp classes. */
    if (is_tijmp_class (referrer_class_tag, blc->tijmp_classes))
        return 0;
    if (is_tijmp_class (class_tag, blc->tijmp_classes))
        return JVMTI_VISIT_OBJECTS;

    lh = (link_header*)jmphash_search (blc->htab, tag);
    if (lh == NULL) {
        lh = new_link_header (tag);
        if (lh == NULL)
            return JVMTI_VISIT_OBJECTS;
        /* Key the entry by the header's own copy of the tag. `tag` points
           at storage owned by JVMTI that is not promised to outlive this
           callback (HotSpot passes a stack temporary). Storing it as the key
           would leave the table comparing against dangling memory.
           NOTE(review): assumes jmphash_insert keeps the key pointer. */
        jmphash_insert (blc->htab, &lh->tag_owned, lh);
    }

    err = (*jvmti)->Allocate (jvmti, sizeof (*li), (unsigned char**)&li);
    if (err != JVMTI_ERROR_NONE || li == NULL)
        return JVMTI_VISIT_OBJECTS;
    li->tag_owner = *referrer_tag_ptr;
    li->reference_kind = reference_kind;
    /* The kind filter above guarantees one of these union members is valid. */
    if (reference_kind == JVMTI_HEAP_REFERENCE_ARRAY_ELEMENT)
        li->index = reference_info->array.index;
    else
        li->index = reference_info->field.index;
    li->next = lh->owners;
    lh->owners = li;

    /* always visit all objects... */
    return JVMTI_VISIT_OBJECTS;
}

static size_t tag_hash_func (void* v, size_t size) {
    jlong* l = (jlong*)v;
    return ((size_t)*l) % size;
}

static int tag_compare (void* v1, void* v2) {
    jlong* l1;
    jlong* l2;
    l1 = (jlong*)v1;
    l2 = (jlong*)v2;
    return *l1 - *l2;
}

/* jmphash_for_each callback that releases one table entry with JVMTI
   Deallocate. It frees every link_info in the owner list, then the
   link_header itself. hkey and the extra argument are not used. */
static void cleanup_htab_entries (void* hkey, void* helem, void* ignored) {
    link_header* header = (link_header*)helem;
    link_info* cur = header->owners;

    while (cur != NULL) {
        link_info* following = cur->next;
        (*jvmti)->Deallocate (jvmti, (unsigned char*)cur);
        cur = following;
    }
    (*jvmti)->Deallocate (jvmti, (unsigned char*)header);
}

/* JNI handles looked up by build_java_objects and used by add_owner_info
   to turn each hash entry into a tijmp.OwnerInfoHeader stored in a map. */
typedef struct {
    JNIEnv *env;
    /* java.util.HashMap class, its put method, and the map being filled */
    jclass map_cls;
    jmethodID m_put;
    jobject map;
    /* tijmp.OwnerInfoHeader class, its (J)V constructor and addOwner */
    jclass oih_cls;
    jmethodID m_init_oih;
    jmethodID m_add_owner;
    /* java.lang.Long class and its (J)V constructor, used for map keys */
    jclass long_cls;
    jmethodID m_init_long;
} aoi_helper;

static void add_owner_info (void* hkey, void* helem, void* data) {
    jobject oih;
    link_info* li;
    jobject owned;
    aoi_helper* h = (aoi_helper*)data;
    link_header* lh = (link_header*)helem;
    
    oih = (*h->env)->NewObject (h->env, h->oih_cls, 
				h->m_init_oih, lh->tag_owned);
    for (li = lh->owners; li != NULL; li = li->next) {
	(*h->env)->CallVoidMethod (h->env, oih, h->m_add_owner, li->tag_owner, 
				   li->reference_kind, li->index);
    }
    owned = (*h->env)->NewObject (h->env, h->long_cls, h->m_init_long, 
				  lh->tag_owned);
    (*h->env)->CallObjectMethod (h->env, h->map, h->m_put, owned, oih);
}

/*
 * Converts the back-link hash table into a java.util.HashMap. The map goes
 * from java.lang.Long(owned tag) to tijmp.OwnerInfoHeader, filled entry by
 * entry by add_owner_info.
 *
 * Returns the map (a local reference). Returns NULL if a class or method
 * lookup fails, if the map cannot be created, or if filling the map leaves
 * a Java exception pending. That exception is left pending for the caller.
 * Each step is checked because a failing JNI call leaves an exception
 * pending, and further JNI calls would then be invalid.
 */
static jobject build_java_objects (JNIEnv* env, hashtab* htab) {
    jmethodID m_init;
    jobject map;
    aoi_helper h;
    const char* put_sign =
        "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;";

    h.env = env;
    h.map_cls = (*env)->FindClass (env, "java/util/HashMap");
    if (h.map_cls == NULL)
        return NULL;
    m_init = (*env)->GetMethodID (env, h.map_cls, "<init>", "(I)V");
    if (m_init == NULL)
        return NULL;
    h.m_put = (*env)->GetMethodID (env, h.map_cls, "put", put_sign);
    if (h.m_put == NULL)
        return NULL;

    h.oih_cls = (*env)->FindClass (env, "tijmp/OwnerInfoHeader");
    if (h.oih_cls == NULL)
        return NULL;
    h.m_init_oih = (*env)->GetMethodID (env, h.oih_cls, "<init>", "(J)V");
    if (h.m_init_oih == NULL)
        return NULL;
    h.m_add_owner = (*env)->GetMethodID (env, h.oih_cls, "addOwner", "(JBI)V");
    if (h.m_add_owner == NULL)
        return NULL;

    h.long_cls = (*env)->FindClass (env, "java/lang/Long");
    if (h.long_cls == NULL)
        return NULL;
    h.m_init_long = (*env)->GetMethodID (env, h.long_cls, "<init>", "(J)V");
    if (h.m_init_long == NULL)
        return NULL;

    /* The initial capacity is read through varargs as a jint, so pass
       exactly that type. */
    map = (*env)->NewObject (env, h.map_cls, m_init,
                             (jint)jmphash_size (htab));
    if (map == NULL)
        return NULL;
    h.map = map;

    jmphash_for_each (htab, add_owner_info, &h);
    if ((*env)->ExceptionCheck (env)) {
        (*env)->DeleteLocalRef (env, map);
        return NULL;
    }
    return map;
}

/*
 * Copies the first so->next_pos tags of so->tags into a new Java long[].
 *
 * Returns the array, or NULL in three cases:
 * - so->next_pos is negative;
 * - a Java exception is already pending (JNI calls would be invalid);
 * - NewLongArray fails, leaving an OutOfMemoryError pending. In this case
 *   SetLongArrayRegion would otherwise be called on NULL.
 */
static jlongArray build_start_object_array (JNIEnv* env, tag_list* so) {
    jlongArray la;

    if (so->next_pos < 0)
        return NULL;
    if ((*env)->ExceptionCheck (env))
        return NULL;
    la = (*env)->NewLongArray (env, so->next_pos);
    if (la == NULL)
        return NULL;
    (*env)->SetLongArrayRegion (env, la, 0, so->next_pos, so->tags);
    return la;
}

/*
 * IterateThroughHeap callback with two jobs:
 * - Give every untagged object the current value of current_object_tag,
 *   then decrement that counter.
 * - Append the (possibly new) tag to the tag_list passed as user_data when
 *   the object's class tag equals that list's clz_tag.
 *
 * Always continues the iteration.
 */
static jint JNICALL tag_untagged_objects (jlong class_tag, jlong size,
                                          jlong* tag, jint length,
                                          void* user_data) {
    tag_list* starts = (tag_list*)user_data;

    if (*tag == 0) {
        *tag = current_object_tag;
        current_object_tag--;
    }
    if (class_tag == starts->clz_tag)
        add_tag (starts, *tag);
    return JVMTI_VISIT_OBJECTS;
}

/*
 * Calls the static Java method
 * tijmp.TIJMPController.owners(java.util.Map, long[]) with the owner map and
 * the start-object tags.
 *
 * It does nothing in three cases:
 * - a Java exception is already pending;
 * - the class cannot be found (previously the NULL class was passed on to
 *   GetStaticMethodID);
 * - the method cannot be found.
 */
static void report_owner_info (JNIEnv* env, jobject map,
                               jlongArray start_objects) {
    jclass cls;
    jmethodID m_own;
    const char* sign = "(Ljava/util/Map;[J)V";

    if ((*env)->ExceptionCheck (env))
        return;
    cls = (*env)->FindClass (env, "tijmp/TIJMPController");
    if (cls == NULL)
        return;
    m_own = (*env)->GetStaticMethodID (env, cls, "owners", sign);
    if (m_own != NULL)
        (*env)->CallStaticVoidMethod (env, cls, m_own, map, start_objects);
    (*env)->DeleteLocalRef (env, cls);
}

/*
 * Finds, for the instances of clz, which objects refer to what, and reports
 * the result to tijmp.TIJMPController.owners.
 *
 * Steps:
 * 1. Force a GC.
 * 2. Tag the loaded classes and collect the tijmp classes.
 * 3. Tag every untagged object, collecting the tags of clz instances.
 * 4. Follow all references to build the back-link table.
 * 5. Convert the result to Java objects and report it.
 * 6. Release every native resource: both tag lists and the hash table.
 *    Previously the tijmp_classes list was never cleaned up.
 */
void find_and_show_owners (JNIEnv* env, jclass clz) {
    jvmtiError err;
    jvmtiHeapCallbacks callbacks;
    hashtab* htab;
    jobject ownerinfomap;
    jlong nanos_start;
    jlong nanos_end;
    jlongArray start_objects;
    tag_list so;
    tag_list tijmp_classes;
    jint class_count;
    jclass* cp = NULL;
    back_link_control blc;

    /* force gc to remove all garbage. */
    force_gc ();

    (*jvmti)->GetTime (jvmti, &nanos_start);
    /* tag classes; the class array handed back is released right away. */
    tag_classes (env, &class_count, &cp);
    (*jvmti)->Deallocate (jvmti, (unsigned char*)cp);
    setup_tag_list (env, &tijmp_classes, clz);
    find_tijmp_classes (env, &tijmp_classes);

    /* tag all other objects. Zero the whole callback struct so unused and
       reserved members are not left holding stack garbage. */
    setup_tag_list (env, &so, clz);
    memset (&callbacks, 0, sizeof callbacks);
    callbacks.heap_iteration_callback = tag_untagged_objects;
    err = (*jvmti)->IterateThroughHeap (jvmti, 0, NULL, &callbacks, &so);
    if (err != JVMTI_ERROR_NONE)
        handle_global_error (err);

    fprintf (stderr, "linking back\n");
    /* create link back hash (100000 is the initial table size). */
    htab = jmphash_new (100000, tag_hash_func, tag_compare);
    if (htab == NULL) {
        cleanup_tag_list (&so);
        cleanup_tag_list (&tijmp_classes);
        return;
    }
    blc.htab = htab;
    blc.tijmp_classes = &tijmp_classes;
    callbacks.heap_iteration_callback = NULL;
    callbacks.heap_reference_callback = link_back;
    err = (*jvmti)->FollowReferences (jvmti, 0, NULL, NULL, &callbacks, &blc);
    if (err != JVMTI_ERROR_NONE)
        handle_global_error (err);

    ownerinfomap = build_java_objects (env, htab);
    start_objects = build_start_object_array (env, &so);

    (*jvmti)->GetTime (jvmti, &nanos_end);
    nanos_end -= nanos_start;
    /* jlong is not `long` on every platform (e.g. Windows, 32-bit), so
       print it through long long. */
    fprintf (stdout, "building back links took: %lld nanos\n",
             (long long)nanos_end);

    report_owner_info (env, ownerinfomap, start_objects);

    cleanup_tag_list (&so);
    /* NOTE(review): assumes cleanup_tag_list releases any list built by
       setup_tag_list, as it does for `so` above. */
    cleanup_tag_list (&tijmp_classes);
    jmphash_for_each (htab, cleanup_htab_entries, NULL);
    jmphash_free (htab);
}

/* Linear search: returns the index of the first element of tags[0..len)
   equal to tag, or -1 if no element matches. */
static jsize find_pos (jsize len, jlong* tags, jlong tag) {
    jsize pos = 0;
    while (pos < len) {
        if (tags[pos] == tag)
            return pos;
        pos++;
    }
    return -1;
}

/*
 * Resolves tags back to objects.
 *
 * Returns a java.lang.Object[] with the same length as la. Element k is the
 * object JVMTI reports as tagged la[k], or null if none was returned for
 * that tag. If GetObjectsWithTags fails, every element is null.
 *
 * Returns NULL in two cases:
 * - la is NULL;
 * - a JNI step fails (GetLongArrayElements, FindClass, NewObjectArray).
 *   For FindClass and NewObjectArray, JNI leaves an exception pending.
 *
 * Resource handling, both previously missing:
 * - The result arrays from GetObjectsWithTags are freed with JVMTI
 *   Deallocate.
 * - The object references it returns (JNI local references) are deleted
 *   after being stored.
 */
jobjectArray get_objects_for_tags (JNIEnv* env, jlongArray la) {
    jlong* tags;
    jint count = 0;
    jsize len;
    jclass cls;
    jobject* objects = NULL;
    jlong* obj_tags = NULL;
    jobjectArray oa = NULL;
    jvmtiError err;
    jint i;

    if (la == NULL)
        return NULL;
    len = (*env)->GetArrayLength (env, la);
    tags = (*env)->GetLongArrayElements (env, la, NULL);
    if (tags == NULL)
        return NULL;

    err = (*jvmti)->GetObjectsWithTags (jvmti, len, tags, &count,
                                        &objects, &obj_tags);
    if (err != JVMTI_ERROR_NONE) {
        /* The outputs are not valid on error; leave every slot null. */
        count = 0;
        objects = NULL;
        obj_tags = NULL;
    }

    cls = (*env)->FindClass (env, "java/lang/Object");
    if (cls != NULL) {
        oa = (*env)->NewObjectArray (env, len, cls, NULL);
        (*env)->DeleteLocalRef (env, cls);
    }
    for (i = 0; i < count; i++) {
        if (oa != NULL) {
            jsize pos = find_pos (len, tags, obj_tags[i]);
            if (pos > -1)
                (*env)->SetObjectArrayElement (env, oa, pos, objects[i]);
        }
        /* Release each returned local reference so a large tag list does
           not exhaust the local reference table. */
        (*env)->DeleteLocalRef (env, objects[i]);
    }

    /* Deallocate ignores NULL, which covers the error path above. */
    (*jvmti)->Deallocate (jvmti, (unsigned char*)objects);
    (*jvmti)->Deallocate (jvmti, (unsigned char*)obj_tags);
    (*env)->ReleaseLongArrayElements (env, la, tags, JNI_ABORT);
    return oa;
}
