-- Copyright 2006-2008 Deutsches Forschungszentrum fuer Kuenstliche Intelligenz 
-- or its licensors, as applicable.
-- 
-- You may not use this file except under the terms of the accompanying license.
-- 
-- Licensed under the Apache License, Version 2.0 (the "License"); you
-- may not use this file except in compliance with the License. You may
-- obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
-- 
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
-- 
-- Project: ocroscript
-- File: check-train-valid-bpnet-feature.lua
-- Purpose: train and test a backpropagation net on isolated characters, 
--          with few features and wider feature map
-- Responsible: rangoni
-- Reviewer: 
-- Primary Repository: 
-- Web Sites: www.iupr.org, www.dfki.de, www.ocropus.org

require 'lib.util'
require 'lib.datasets'
import_all(ocr)

--- Load a list of isolated-character training samples.
-- Each usable line of the list file holds an image path and a transcript
-- path separated by whitespace; both are resolved relative to the directory
-- that contains the list file. Lines that do not have that shape (blank
-- lines, a single token, more than two tokens, leading whitespace) are
-- skipped.
-- Raises an error if a transcript does not contain exactly one item.
-- @param dataset_path path of the list file
-- @return array of {image = <bytearray>, text = <single transcript item>};
--         the same table is also stored in the global `train_data`
function load_train_data(dataset_path)
    local dataset_dirname = path.dirname(dataset_path)
    -- Deliberately a global: code elsewhere in this file reads `train_data`
    -- directly, so the assignment is part of this function's contract.
    train_data = {}
    local samples = train_data
    for line in io.lines(dataset_path) do
        -- The separator must be %s+ (at least one whitespace character).
        -- With %s* a single-token line such as "foo.png" would backtrack
        -- into "foo.pn" + "g" and be treated as an image/transcript pair.
        local image_path, text_path = line:match("^(%S+)%s+(%S+)%s*$")
        if image_path and text_path then
            image_path = path.join(dataset_dirname, image_path)
            text_path = path.join(dataset_dirname, text_path)
            local image = bytearray()
            iulib.read_image_gray(image, image_path)
            -- read_transcript returns an array (of nustrings, per the
            -- original author's note); only one item is allowed here.
            local transcript = datasets.read_transcript(text_path)
            assert(#transcript == 1,
                   "transcription must contain only one character")
            samples[#samples + 1] = {image = image, text = transcript[1]}
        end
    end
    return samples
end

--- Train a classifier on a set of samples and save it to disk.
-- Every entry of `train_data` is handed to the classifier between its
-- startTraining/finishTraining calls; the classifier is then saved.
-- @param train_data table of records with `image` and `text` fields
-- @param result_path file name passed to the classifier's save method
-- @param cc classifier object providing startTraining, addTrainingChar,
--           finishTraining and save
function train(train_data, result_path, cc)
    cc:startTraining()
    for _, sample in pairs(train_data) do
        local image, label = sample.image, sample.text
        cc:addTrainingChar(image, label)
    end
    cc:finishTraining()
    cc:save(result_path)
end

--- Classify every sample and report the error rate.
-- Each sample image is copied, passed through `degrade` four times, and
-- given to the classifier; the best result is compared (as UTF-8) with the
-- sample's expected text. Mismatches are printed, followed by a summary.
-- @param test_data table of records with `image` and `text` fields
-- @param cc classifier object providing setImage and best
-- @return error rate in percent; 0 when `test_data` is empty
function rec_chars(test_data, cc)
    local total = 0
    local errors = 0
    for _, sample in pairs(test_data) do
        local recognized = nustring.nustring()
        -- Work on a copy so the caller's image is not degraded.
        local image = bytearray:new()
        narray.copy(image, sample.image)
        for _ = 1, 4 do
            degrade(image)
        end
        cc:setImage(image)
        cc:best(recognized)
        total = total + 1
        local got, expected = recognized:utf8(), sample.text:utf8()
        if got ~= expected then
            print(string.format("out:%s    target:%s", got, expected))
            errors = errors + 1
        end
        -- The copy was created with bytearray:new(), so release it here.
        image:delete()
    end
    -- Guard against 0/0: with no samples the original produced NaN, which
    -- then ended up in the printed summary and in the return value.
    local rate = 0
    if total > 0 then
        rate = errors * 100.0 / total
    end
    print(string.format("error: %d/%d  %g%%", errors, total, rate))
    return rate
end

-- Command-line driver.
-- arg[1]: list file naming the image/transcript pairs to load
-- arg[2]: log file that receives the final error rate
-- The script trains a backpropagation classifier on the loaded samples,
-- saves it, reloads it into a fresh classifier and evaluates that one on
-- the same samples (rec_chars degrades each image before classifying).
if #arg < 2 then
    print("usage: ... <basename_list> <output_log>")
    os.exit(1)
end

local features = "110100000"
local dim_x = 15
local dim_y = 15
local dataset_path = arg[1]
local result_path = "bpnet.classifier"

local database = load_train_data(dataset_path)

print(string.format("features: %s", features))
print(string.format("features size: %d %d", dim_x, dim_y))

local cc_train = make_BpnetCharacterClassifier(features, dim_x, dim_y)
cc_train:set("nhidden", 100)
cc_train:set("epochs", 5)
cc_train:set("learningrate", 0.2)
cc_train:set("testportion", 0.2)
cc_train:set("normalize", 1)
cc_train:set("shuffle", 1)

train(database, result_path, cc_train)

-- Reload from disk so the evaluation exercises the saved classifier.
local cc_test = make_BpnetCharacterClassifier()
cc_test:load(result_path)

-- Evaluate on the table returned by the loader rather than on a global
-- that load_train_data happens to set as a side effect.
local res = rec_chars(database, cc_test)

local log = util.secure_open(arg[2], "w")
log:write(string.format("%g\n", res))
-- Close explicitly so the result is flushed before the script exits.
log:close()

