// -*- C++ -*-

// Copyright 2006-2007 Deutsches Forschungszentrum fuer Kuenstliche Intelligenz 
// or its licensors, as applicable.
// Copyright 1995-2005 by Thomas M. Breuel
// 
// You may not use this file except under the terms of the accompanying license.
// 
// Licensed under the Apache License, Version 2.0 (the "License"); you
// may not use this file except in compliance with the License. You may
// obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
// 
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// 
// Project:
// File: 
// Purpose: 
// Responsible: tmb
// Reviewer: 
// Primary Repository: 
// Web Sites: 


// FIXME this should really work word by word, centered on each word;
// otherwise it does the wrong thing for lines that have not been deskewed
// (the original version worked word by word).

#include <stdlib.h>
#include <stdio.h>
#include <math.h>
#include <ctype.h>
#include "colib.h"
#include "imgio.h"
#include "imglib.h"
#include "segmentation.h"
#include "queue.h"
#include "ocr-utils.h"
#include "ocr-segmentations.h"
#include "logger.h"

using namespace ocropus;
using namespace iulib;
using namespace colib;

namespace {
    // File-local logger for this segmenter; the log_main(...) and
    // log_main.format(...) calls in this file report under the
    // "lineseg.seg-cuts" name.
    Logger log_main("lineseg.seg-cuts");
}

namespace hacked_labels {
    // FIXME get rid of this --tmb
    //
    // Connected-component labeling with optional "guidance": neighboring
    // foreground pixels are only joined when their guidance values are
    // equal.  The curved-cut segmenter uses this to keep pixels on
    // different sides of a cut in different components.

    /// Disjoint-set forest over the integers [0,max), using union by rank
    /// and path compression.  Elements must be created with make_set()
    /// before being passed to find_set() or make_union().
    struct UnionFind {
        narray<int> p,rank;
        explicit UnionFind(int max=10000) {
            p.resize(max);
            fill(p,-1);
            rank.resize(max);
            fill(rank,-1);
        }
        /// Create the singleton set {x}.
        void make_set(int x) {
            if(x<0)
                throw "range error (UnionFind::make_set)";
            p[x] = x;
            rank[x] = 0;
        }
        /// Merge the sets containing x and y (no-op if already merged).
        void make_union(int x,int y) {
            if(x==y) return;
            int rx = find_set(x);
            int ry = find_set(y);
            // Already in the same set: linking a root to itself would
            // needlessly increase its rank.
            if(rx==ry) return;
            link(rx,ry);
        }
        /// Attach one root below the other, by rank.  x and y must be
        /// distinct roots.
        void link(int x,int y) {
            if(rank[x]>rank[y]) {
                p[y] = x;
            } else {
                p[x] = y;
                if(rank[x]==rank[y]) rank[y]++;
            }
        }
        /// Return the representative of x's set, compressing the path.
        int find_set(int x) {
            if(x<0)
                throw "range error (UnionFind::find_set)";
            if(p[x]<0)
                throw "trying to find a set that hasn't been created yet";
            if(x!=p[x]) p[x] = find_set(p[x]);
            return p[x];
        }
    };

    /// Label the connected components of the foreground (nonzero pixels)
    /// of `image` in place.  Two neighboring foreground pixels are
    /// connected only if their `guidance` values are equal.  With
    /// four_connected, only horizontal/vertical neighbors count; otherwise
    /// diagonal neighbors count too.  The labels are finally passed through
    /// renumber_labels(image,1), whose result is returned.
    static int label_components_internal(intarray &image,
                                         intarray &guidance,
                                         bool four_connected) {
        int w = image.dim(0), h = image.dim(1);
        // We slice the image into columns and call make_set() for every
        // run of foreground pixels within a column.  A run also ends where
        // the guidance value changes, so in the worst case every pixel
        // starts a new run: w*h labels, plus 1 because label 0 is reserved
        // for the background.
        UnionFind uf(w * h + 1);
        uf.make_set(0);
        int top = 1;
        for(int i=0;i<image.length1d();i++) image.at1d(i) = !!image.at1d(i);
        int range = four_connected?0:1;
        for(int i=0;i<w;i++) {
            int current_label = 0;
            for(int j=0;j<h;j++) {
                if(!image(i,j)) {
                    current_label = 0;
                    continue;
                }
                int guide = guidance(i,j);
                // The vertical neighbor above belongs to the current run;
                // do not join it if it lies on the other side of a guidance
                // boundary (e.g. a diagonal cut crossing this column).
                if(current_label && guidance(i,j-1) != guide)
                    current_label = 0;
                if(!current_label) {
                    current_label = top;
                    uf.make_set(top);
                    top++;
                }
                if(i) {
                    // join with same-guidance neighbors in the previous column
                    for(int delta=-range;delta<=range;delta++) {
                        if(j + delta < 0 || j + delta >= h)
                            continue;
                        int adj_label = image.at(i-1,j+delta);
                        int adj_guide = guidance.at(i-1,j+delta);
                        if(adj_label && guide == adj_guide)
                            uf.make_union(current_label,adj_label);
                    }
                }
                // provisional label; resolved to its set root below
                image(i,j) = current_label;
            }
        }
        for(int i=0;i<image.length1d();i++) {
            if(!image.at1d(i)) continue;
            image.at1d(i) = uf.find_set(image.at1d(i));
        }
        return renumber_labels(image,1);
    }

    /// Label the connected components of the foreground of `image` in
    /// place, without guidance (all neighboring foreground pixels join).
    int label_components(intarray &image,bool four_connected=false) {
        intarray guidance;
        makelike(guidance, image);
        fill(guidance, 0);
        return label_components_internal(image, guidance, four_connected);
    }

    /// Label the foreground (nonzero pixels) of `guide` into `image`,
    /// joining neighboring pixels only when their `guide` values are equal.
    /// `image` and `guide` may be the same array.
    int label_components(intarray &image,
                         intarray &guide, 
                         bool four_connected=false) {
        int result;
        if(&image == &guide) {
            // work on a copy so the guidance values stay intact
            intarray tmp;
            copy(tmp, guide);
            result = label_components_internal(tmp, guide, four_connected);
            copy(image, tmp);
        } else {
            copy(image, guide);
            result = label_components_internal(image, guide, four_connected);
        }
        return result;
    }
}

/// Sliding-window minimum: result(i) is the smallest data(k) with
/// |k-i| <= r, where the window is clipped to the valid range [0,n).
static void local_min(floatarray &result,floatarray &data,int r) {
    const int n = data.length();
    result.resize(n);
    for(int center=0;center<n;center++) {
        const int first = (center-r < 0) ? 0 : center-r;
        const int last = (center+r > n-1) ? n-1 : center+r;
        float smallest = data(center);
        for(int k=first;k<=last;k++) {
            // written as !(>=) so NaN values are treated like a plain scan
            if(!(data(k)>=smallest)) smallest = data(k);
        }
        result(center) = smallest;
    }
}

/// Collect into `result` the interior indices i (0 < i < n-1) where
/// data(i) is at most `threshold`, is the minimum of its radius-r window,
/// is <= its left neighbor and strictly < its right neighbor.  The strict
/// right comparison means only the right end of a flat valley is reported.
static void local_minima(intarray &result,floatarray &data,int r,float threshold) {
    floatarray window_min;
    local_min(window_min,data,r);
    result.clear();
    const int n = data.length();
    for(int i=1;i+1<n;i++) {
        const float v = data(i);
        const bool below_threshold = v<=threshold;
        const bool is_window_min = v<=window_min(i);
        const bool not_above_left = v<=data(i-1);
        const bool below_right = v<data(i+1);
        if(below_threshold && is_window_min && not_above_left && below_right)
            result.push(i);
    }
}

/// Abstract interface of the curved-cut line segmenter: the tunable
/// path-search parameters plus the steps an implementation provides
/// (set_image, find_allcuts, find_bestcuts).
struct CurvedCutSegmenter {
    // cost added for a straight-down step, and for a diagonal step out of
    // a pixel whose weight is zero / positive / negative
    int down_cost, outside_diagonal_cost, inside_diagonal_cost,
        boundary_diagonal_cost;
    // per-pixel weights assigned by the implementation's set_image()
    int inside_weight, boundary_weight, outside_weight;
    // selection of best cuts: local-minimum radius and cost threshold
    int min_range;
    float min_thresh;

    virtual void params_for_lines() = 0;
    virtual void find_allcuts() = 0;
    virtual void find_bestcuts() = 0;
    virtual void set_image(bytearray &image) = 0;
    virtual ~CurvedCutSegmenter() {}
};

struct CurvedCutSegmenterImpl : CurvedCutSegmenter {
    // input
    intarray wimage;
    int where;

    // output
    intarray costs;
    intarray sources;
    int direction;
    int limit;

    intarray bestcuts;

    strbuf debug;
    intarray dimage;
    
    narray< narray <point> > cuts;
    floatarray cutcosts;

    CurvedCutSegmenterImpl() {
        //params_for_chars();
        params_for_lines();
        //params_from_hwrec_c();
    }

    void params_for_lines() {
        down_cost = 0;
        outside_diagonal_cost = 4;
        inside_diagonal_cost = 4;
        boundary_diagonal_cost = 0;
        outside_weight = 0;     
        boundary_weight = -1;   
        inside_weight = 4;      
        min_range = 3;
        //min_thresh = -2.0;
        min_thresh = 10.0;
    }

    // this function calculates the actual costs!
    void step(int x0,int x1,int y) {
        int w = wimage.dim(0),h = wimage.dim(1);
        Queue<point> queue(w*h);
        for(int i=x0;i<x1;i++) queue.enqueue(point(i,y));
        int low = 1;
        int high = wimage.dim(0)-1;
        
        while(!queue.empty()) {
            point p = queue.dequeue();
            int i = p.x, j = p.y;
            int cost = costs(i,j);
            int ncost = cost+wimage(i,j)+down_cost;
            if(costs(i,j+direction)>ncost) {
                costs(i,j+direction) = ncost;
                sources(i,j+direction) = i;
                if(j+direction!=limit) queue.enqueue(point(i,j+direction));
            }
            if(i>low) {
                if(wimage(i,j)==0)
                    ncost = cost+wimage(i,j)+outside_diagonal_cost;
                else if(wimage(i,j)>0)
                    ncost = cost+wimage(i,j)+inside_diagonal_cost;
                else if(wimage(i,j)<0)
                    ncost = cost+wimage(i,j)+boundary_diagonal_cost;
                if(costs(i-1,j+direction)>ncost) {
                    costs(i-1,j+direction) = ncost;
                    sources(i-1,j+direction) = i;
                    if(j+direction!=limit) queue.enqueue(point(i-1,j+direction));
                }
            }
            if(i<high) {
                if(wimage(i,j)==0)
                    ncost = cost+wimage(i,j)+outside_diagonal_cost;
                else if(wimage(i,j)>0)
                    ncost = cost+wimage(i,j)+inside_diagonal_cost;
                else if(wimage(i,j)<0)
                    ncost = cost+wimage(i,j)+boundary_diagonal_cost;
                if(costs(i+1,j+direction)>ncost) {
                    costs(i+1,j+direction) = ncost;
                    sources(i+1,j+direction) = i;
                    if(j+direction!=limit) queue.enqueue(point(i+1,j+direction));
                }
            }
        }
    }

    void find_allcuts() {
        int w = wimage.dim(0), h = wimage.dim(1);
        // initialize dimensions of cuts, costs etc
        cuts.resize(w);
        cutcosts.resize(w);
        costs.resize(w,h);
        sources.resize(w,h);

        fill(costs, 1000000000);
        for(int i=0;i<w;i++) costs(i,0) = 0;
        fill(sources, -1);
        limit = where;
        direction = 1;
        step(0,w,0);

        for(int x=0;x<w;x++) {
            cutcosts(x) = costs(x,where);
            cuts(x).clear();
            // bottom should probably be initialized with 2*where instead of
            // h, because where cannot be assumed to be h/2. In the most extreme
            // case, the cut could go through 2 pixels in each row
            narray<point> bottom;
            int i = x, j = where;
            while(j>=0) {
                bottom.push(point(i,j));
                i = sources(i,j);
                j--;
            }
            //cuts(x).resize(h);
            for(i=bottom.length()-1;i>=0;i--) cuts(x).push(bottom(i));
        }

        fill(costs, 1000000000);
        for(int i=0;i<w;i++) costs(i,h-1) = 0;
        fill(sources, -1);
        limit = where;
        direction = -1;
        step(0,w,h-1);

        for(int x=0;x<w;x++) {
            cutcosts(x) += costs(x,where);
            // top should probably be initialized with 2*(h-where) instead of
            // h, because where cannot be assumed to be h/2. In the most extreme
            // case, the cut could go through 2 pixels in each row
            narray<point> top;
            int i = x, j = where;
            while(j<h) {
                if(j>where) top.push(point(i,j));
                i = sources(i,j);
                j++;
            }
            for(i=0;i<top.length();i++) cuts(x).push(top(i));
        }

        // add costs for line "where"
        for(int x=0;x<w;x++) {
            cutcosts(x) += wimage(x,where);
        }

    }

    void find_bestcuts() {
        for(int i=0;i<cutcosts.length();i++) ext(dimage,i,int(cutcosts(i)+10)) = 0xff0000;
        for(int i=0;i<cutcosts.length();i++) ext(dimage,i,int(min_thresh+10)) = 0x800000;
        local_minima(bestcuts,cutcosts,min_range,min_thresh);
        for(int i=0;i<bestcuts.length();i++) {
            narray<point> &cut = cuts(bestcuts(i));
            for(int j=0;j<cut.length();j++) {
                point p = cut(j);
                ext(dimage,p.x,p.y) = 0x00ff00;
            }
        }
        if(debug) write_image_packed(debug,dimage);
    }

    void set_image(bytearray &image) {
        copy(dimage,image);
        int w = image.dim(0), h = image.dim(1);
        wimage.resize(w,h);
        fill(wimage, 0);
        float s1 = 0.0, sy = 0.0;
        for(int i=1;i<w;i++) for(int j=0;j<h;j++) {
            if(image(i,j)) { s1++; sy += j; }
            if(!image(i-1,j) && image(i,j)) wimage(i,j) = boundary_weight;
            else if(image(i,j)) wimage(i,j) = inside_weight;
            else wimage(i,j) = outside_weight;
        }
        where = int(sy/s1);
        for(int i=0;i<dimage.dim(0);i++) dimage(i,where) = 0x008000;
    }
};

// CurvedCutSegmenter *makeCurvedCutSegmenter() {
//  return new CurvedCutSegmenterImpl();
// }

class CurvedCutSegmenterToISegmentLineAdapter : public ISegmentLine {
    autoref<CurvedCutSegmenterImpl> segmenter;

    virtual const char *description() {
        return "curved cut segmenter";
    }

    virtual void set(const char *key,const char *value) {
        log_main.format("set parameter %s to sf", key, value);
        if(!strcmp(key,"debug"))
            segmenter->debug = value;
        else
            throw "unknown key";
    }

    virtual void set(const char *key,double value) {
        log_main.format("set parameter %s to %f", key, value);
        if(!strcmp(key,"down_cost"))
            segmenter->down_cost = (int)value;
        else if(!strcmp(key,"outside_diagonal_cost"))
            segmenter->outside_diagonal_cost = (int)value;
        else if(!strcmp(key,"inside_diagonal_cost"))
            segmenter->inside_diagonal_cost = (int)value;
        else if(!strcmp(key,"boundary_diagonal_cost"))
            segmenter->boundary_diagonal_cost = (int)value;
        else if(!strcmp(key,"outside_weight"))
            segmenter->outside_weight = (int)value;
        else if(!strcmp(key,"boundary_weight"))
            segmenter->boundary_weight = (int)value;
        else if(!strcmp(key,"inside_weight"))
            segmenter->inside_weight = (int)value;
        else if(!strcmp(key,"min_range"))
            segmenter->min_range = (int)value;
        else if(!strcmp(key,"min_thresh"))
            segmenter->min_thresh = value;
        else
            throw "unknown key";
    }

    virtual void charseg(intarray &result_segmentation,bytearray &orig_image) {
        log_main("segmenting", orig_image);
        enum {PADDING = 3};
        bytearray image;
        copy(image, orig_image);
        optional_check_background_is_lighter(image);
        binarize_simple(image);
        invert(image);
        pad_by(image, PADDING, PADDING);
        intarray segmentation;
        // pass image to segmenter
        segmenter->set_image(image);
        // find all cuts in the image
        segmenter->find_allcuts();
        // choose the best of all cuts
        segmenter->find_bestcuts();

        // FIXME clean this up --tmb Now, this kind of bothers me.  We
        // have a high level function that invokes a sequence of
        // well-encapsulated lower-level methods (above).  And then we
        // have a bunch of loops doing image processing. That really
        // doesnt look good.  At the very least, the image processing
        // code should get encapsulated into its own method.  But,
        // more generally, the notion of "here is a segmentation, but
        // some of the segments are disconnected, so those components
        // should be relabeled" is generally useful.  There is no
        // reason to stick it at the end of this code, it should
        // really be a separate function:
        // relabel_disconnected_segment_parts(segmentation) or
        // something like that All the code below should go into a
        // separate function.  That function should take a
        // segmentation as an input, identify all the segments that
        // consist of multiple, disconnected components, and relabel
        // them as appropriate

        // the method below has two problems:
        //  1) spurious components for thin lines (this could be 
        //     solved by using a more careful cutting strategy than the three pixels
        //     as used below)
        //  2) you can have more that one subimage between two cuts
        //     (this may actually be desired)
        // so let's try a different method: everything between two cuts is one component
        

        segmentation.resize(image.dim(0),image.dim(1));
        for(int i=0;i<image.dim(0);i++) for(int j=0;j<image.dim(1);j++)
            segmentation(i,j) = image(i,j)?1:0;

        /* 
         * that couldn't work - it can't separate chunks that are both
         * in one CC and between the same cuts, but disconnected
        // multiply connected components with 10000
        // so that we can combine it with the cut-information
        for(int i=0;i<image.dim(0);i++) for(int j=0;j<image.dim(1);j++)
            segmentation(i,j) *=10000; 
        */

        // now include the cut-information
        for(int r=0;r<segmenter->bestcuts.length();r++) {
            int c = segmenter->bestcuts(r);
            narray<point> &cut = segmenter->cuts(c);
            for(int y=0;y<image.dim(1);y++) {
                for(int x=cut(y).x;x<image.dim(0);x++) 
                    if(segmentation(x,y)) segmentation(x,y)++;
            }
        }
        hacked_labels::label_components(segmentation, segmentation);
        
        extract_subimage(result_segmentation,segmentation,PADDING,PADDING,
                         segmentation.dim(0)-PADDING,segmentation.dim(1)-PADDING);
        make_line_segmentation_white(result_segmentation);
        set_line_number(result_segmentation, 1);
        log_main("resulting segmentation", result_segmentation);
    }
};

/// Factory for the curved-cut line segmenter.  The object is allocated
/// with new; presumably the caller takes ownership (confirm with callers).
ISegmentLine *ocropus::make_CurvedCutSegmenter() {
    ISegmentLine *line_segmenter = new CurvedCutSegmenterToISegmentLineAdapter();
    return line_segmenter;
}
