#! /usr/bin/env escript
%% -*- erlang -*-
%%! -pz ../bitcask
%%
%% Bitcask Torture Tests
%% 
%% Pre-populates a bitcask file with a known range of keys. The value for
%% each key stores the key itself and a sequence number for the write
%% (initially 1).
%%
%% The test itself consists of a writer continually rewriting all keys with
%% the next sequence number while other processes read, fold over and merge
%% the cask.
%%

-module(bctt).
-compile([export_all]).
-include_lib("bitcask/include/bitcask.hrl").
%% Test configuration plus the progress counters kept by the controlling
%% process while the test runs.
%%
%% seq is the sequence number of the most recently completed write pass;
%% the writer is always kicked with seq + 1.
-record(state, {seq = 0,
                %% Number of reader / folder / fold_keys / merger runs that
                %% have reported completion (printed in the status line).
                reader_reps=0,
                folder_reps=0,
                fold_keys_reps=0,
                merger_reps=0,
                %% Delay in milliseconds before the stop message is sent;
                %% undefined sends it immediately (printed as "once").
                duration,
                %% Milliseconds between status messages; undefined disables
                %% them.
                status_freq = 10000,
                %% Name of the cask handed to bitcask:open and
                %% bitcask:merge; removed with "rm -rf" when a test starts.
                cask = "bctt.bc",
                %% Keys written are the integers num_keys down to 1.
                num_keys=16384,
                %% Pid of the current writer process (set by start_writer/1).
                writer_pid,
                %% When true the writer is stopped after each completed
                %% write pass and a fresh one is started.
                restart_writer=false,
                %% Number of write_exit messages wait_for_procs/1 waits for.
                writers=1,
                %% Number of reader / folder / fold_keys / merger processes
                %% started by do_test/1.
                readers=0,
                folders=0,
                foldkeys=0,
                mergers=0,
                %% Passed to the writer's bitcask:open as
                %% {max_file_size, N} when it is a positive integer.
                max_file_size=10*1024*1024,
                %% Passed to the writer's bitcask:open as
                %% {open_timeout, N} when it is an integer.
                open_timeout=undefined}).

%% Entry point for running the test from inside an Erlang node, using
%% the default settings of the #state record.
test() ->
    NoOpts = [],
    test(NoOpts).
    
%% Run the test from inside an Erlang node.  Opts is a list of
%% {Name, Value} pairs as accepted by process_arg/2.  The parameters are
%% printed and the test itself runs in a newly spawned process whose pid
%% is returned.
test(Opts) ->
    {ok, Config} = process_args(Opts, #state{}),
    print_params(Config),
    Runner = fun() -> do_test(Config) end,
    spawn(Runner).

%% Escript version, called without any command line arguments.
main() ->
    NoArgs = [],
    main(NoArgs).

%% Escript entry point.  Args is the list of command line strings.
%% Prints the working directory, parses the arguments and either shows
%% the usage text (halting with status 0) or runs the test.  Any
%% exception is reported and the VM halts with status 1.
main(Args) ->
    io:format("PWD: ~p\n", [file:get_cwd()]),
    try
        Parsed = lists:map(fun parse_arg/1, Args),
        case process_args(Parsed, #state{}) of
            syntax ->
                syntax(),
                halt(0);
            {ok, Config} ->
                print_params(Config),
                ensure_bitcask(),
                ensure_deps(),
                do_test(Config)
        end
    catch
        _:Reason ->
            io:format("Failed: ~p\n", [Reason]),
            halt(1)
    end.

%% Print the command line usage summary.
%% Every bracketed entry is a Name=Value argument handled by
%% process_arg/2.  status_freq is listed as well: it is accepted on the
%% command line and controls how often the status line is printed.
syntax() ->
    io:format("bctt [duration=Msecs]\n"
              "     [status_freq=Msecs]\n"
              "     [cask=CaskFile]\n"
              "     [num_keys=NumKeys]\n"
              "     [readers=NumReaders]\n"
              "     [folders=NumFolders]\n"
              "     [foldkeys=NumFoldKeys]\n"
              "     [mergers=NumMergers]\n"
              "     [max_file_size=Bytes]  # set 0 for defaults\n"
              "     [open_timeout=Secs]\n"
              "     [restart_writer=true|false]\n").


%% Run one complete test described by State0 in the calling process:
%% remove the cask, populate it, then keep the writer and the configured
%% readers/folders/fold_keys/mergers running until a stop message
%% arrives, and finally wait for the outstanding processes to report in.
do_test(State0) ->
    %% Workers are linked to this process; trapping exits turns their
    %% termination into 'EXIT' messages handled by the receive loops.
    erlang:process_flag(trap_exit, true),
    os:cmd("rm -rf " ++ State0#state.cask),

    %% Put a base of keys - each value has {Key, Seq}, starting from 1.
    io:format("\nInitial populate.\n"),
    State1 = start_writer(State0),
    kick_writer(State1),
    wait_for_writer(State1),

    %% Start continually rewriting the keys and optionally reading,
    %% folding and merging
    State = State1#state{seq = 1},
    kick_writer(State),
    start_readers(State, State#state.readers),
    start_folders(State, State#state.folders),
    start_fold_keys(State, State#state.foldkeys),
    start_mergers(State, State#state.mergers),
    schedule_status(State),
    %% Arrange for the stop message: immediately when no duration was
    %% given, otherwise after Duration milliseconds.
    case State#state.duration of 
        undefined ->
            self() ! stop;
        Duration ->
            timer:send_after(Duration, self(), stop)
    end,
    io:format("Starting test.\n"),
    %% restart_procs/1 returns once it sees stop, an abnormal exit or an
    %% unexpected message.
    EndState = restart_procs(State),
    stop_writer(EndState),
    wait_for_procs(EndState).   

%% Main loop of the controlling process while the test is running.
%% Each completion message from a worker starts a replacement worker and
%% bumps the matching counter.  A completed write pass advances seq and
%% either kicks the writer again or, when restart_writer is set, stops it
%% so that the following write_exit message starts a fresh one.
%% Returns the current state on stop, on an abnormal 'EXIT' or on an
%% unexpected message.
restart_procs(State) ->
    receive
        stop ->
            io:format("Test ending...\n"),
            State;
        status ->
            #state{seq = Seq,
                   reader_reps = ReaderReps,
                   folder_reps = FolderReps,
                   fold_keys_reps = FoldKeysReps,
                   merger_reps = MergerReps} = State,
            io:format("Writer seq: ~p  Readers=~p Folders=~p FoldKeys=~p Merges=~p\n",
                      [Seq, ReaderReps, FolderReps, FoldKeysReps, MergerReps]),
            schedule_status(State),
            restart_procs(State);
        {write_done, WriteSeq} ->
            %% The writer must have completed exactly the next pass.
            true = (State#state.seq+1 =:= WriteSeq),
            Advanced = State#state{seq = WriteSeq},
            case State#state.restart_writer of
                true ->
                    stop_writer(State);
                false ->
                    kick_writer(Advanced)
            end,
            restart_procs(Advanced);
        write_exit ->
            Restarted = start_writer(State),
            kick_writer(Restarted),
            restart_procs(Restarted);
        {read_done, _ReadSeq} ->
            start_readers(State, 1),
            ReadsDone = State#state.reader_reps + 1,
            restart_procs(State#state{reader_reps = ReadsDone});
        {fold_done, _FoldSeq} ->
            start_folders(State, 1),
            FoldsDone = State#state.folder_reps + 1,
            restart_procs(State#state{folder_reps = FoldsDone});
        {fold_keys_done, _FoldSeq} ->
            start_fold_keys(State, 1),
            KeyFoldsDone = State#state.fold_keys_reps + 1,
            restart_procs(State#state{fold_keys_reps = KeyFoldsDone});
        merge_done ->
            start_mergers(State, 1),
            MergesDone = State#state.merger_reps + 1,
            restart_procs(State#state{merger_reps = MergesDone});
        {'EXIT', _From, normal} ->
            restart_procs(State);
        {'EXIT', From, Reason} ->
            io:format("Test process ~p died\n~p\n", [From, Reason]),
            State;
        Other ->
            io:format("Restart procs got unexpected message\n~p\n", [Other]),
            State
    end.

%% Block until the writer reports that the initial populate (sequence 1)
%% is complete.  An 'EXIT' from the writer itself raises
%% {initial_write_failed, Reason}; 'EXIT' messages from anything else
%% (the os:cmd call can generate one) are discarded.
wait_for_writer(#state{writer_pid = Writer} = State) ->
    receive
        {write_done, 1} ->
            ok;
        {'EXIT', Writer, Reason} ->
            erlang:error({initial_write_failed, Reason});
        {'EXIT', _OtherPid, _Reason} ->
            wait_for_writer(State)
    end.

%% Drain the completion messages of the processes still running after the
%% test has been stopped.  The writers/readers/folders/foldkeys/mergers
%% fields are used as countdowns; when all reach zero the test is
%% complete.  An abnormal 'EXIT' or an unexpected message is reported and
%% ends the wait.
wait_for_procs(#state{writers = 0, readers = 0, folders = 0, foldkeys = 0, mergers = 0}) ->
    io:format("Test complete\n");
wait_for_procs(State) ->
    #state{writers = Writers,
           readers = Readers,
           folders = Folders,
           foldkeys = FoldKeys,
           mergers = Mergers} = State,
    receive
        write_exit ->
            wait_for_procs(State#state{writers = Writers - 1});
        {read_done, _ReadSeq} ->
            wait_for_procs(State#state{readers = Readers - 1});
        {fold_done, _FoldSeq} ->
            wait_for_procs(State#state{folders = Folders - 1});
        {fold_keys_done, _FoldKeySeq} ->
            wait_for_procs(State#state{foldkeys = FoldKeys - 1});
        merge_done ->
            wait_for_procs(State#state{mergers = Mergers - 1});
        status ->
            wait_for_procs(State);
        {write_done, _WriteSeq} ->
            wait_for_procs(State);
        {'EXIT', _From, normal} ->
            wait_for_procs(State);
        {'EXIT', From, Reason} ->
            io:format("Test process ~p died\n~p\n", [From, Reason]);
        Other ->
            io:format("Wait for procs got unexpected message\n~p\n", [Other])
    end.

%% Arrange for a status message to be sent to the calling process after
%% status_freq milliseconds.  Does nothing when status_freq is undefined.
schedule_status(#state{status_freq = undefined}) ->
    ok;
schedule_status(#state{status_freq = Interval}) ->
    timer:send_after(Interval, self(), status).
            
%% Spawn a linked writer process that opens the cask with writer_opts/1
%% and then serves start/stop requests in write_proc/2, reporting to the
%% calling process.  Returns State with writer_pid set to the new pid.
start_writer(State) ->
    Parent = self(),
    WriterMain = fun() ->
                         Ref = bitcask:open(State#state.cask, writer_opts(State)),
                         write_proc(Ref, Parent)
                 end,
    Pid = spawn_link(WriterMain),
%%X    io:format("Started writer pid ~p\n", [Pid]),
    State#state{writer_pid = Pid}.

%% Build the option list for the writer's bitcask:open call.
%% Always contains read_write; {max_file_size, N} is added when
%% max_file_size is a positive integer and {open_timeout, N} when
%% open_timeout is an integer.
writer_opts(#state{max_file_size = MaxSize, open_timeout = OpenTimeout}) ->
    SizeOpts = [{max_file_size, MaxSize} || is_integer(MaxSize), MaxSize > 0],
    TimeoutOpts = [{open_timeout, OpenTimeout} || is_integer(OpenTimeout)],
    [read_write] ++ SizeOpts ++ TimeoutOpts.

%% Spawn the requested number of linked reader processes.  Each runs
%% read_proc/4 against the cask with the sequence number current at
%% spawn time and reports back to the calling process.
start_readers(_State, 0) ->
    ok;
start_readers(#state{cask = Cask, num_keys = NumKeys, seq = Seq} = State, Remaining) ->
    Parent = self(),
    spawn_link(fun() -> read_proc(Cask, NumKeys, Seq, Parent) end),
    start_readers(State, Remaining - 1).

%% Spawn the requested number of linked folder processes.  Each runs
%% fold_proc/4 against the cask with the sequence number current at
%% spawn time and reports back to the calling process.
start_folders(_State, 0) ->
    ok;
start_folders(#state{cask = Cask, num_keys = NumKeys, seq = Seq} = State, Remaining) ->
    Parent = self(),
    spawn_link(fun() -> fold_proc(Cask, NumKeys, Seq, Parent) end),
    start_folders(State, Remaining - 1).

%% Spawn the requested number of linked fold_keys processes.  Each runs
%% fold_keys_proc/4 against the cask and reports back to the calling
%% process.
start_fold_keys(_State, 0) ->
    ok;
start_fold_keys(#state{cask = Cask, num_keys = NumKeys, seq = Seq} = State, Remaining) ->
    Parent = self(),
    spawn_link(fun() -> fold_keys_proc(Cask, NumKeys, Seq, Parent) end),
    start_fold_keys(State, Remaining - 1).

%% Spawn the requested number of linked merger processes.  Each runs
%% merge_proc/2 against the cask and reports back to the calling process.
start_mergers(_State, 0) ->
    ok;
start_mergers(#state{cask = Cask} = State, Remaining) ->
    Parent = self(),
    spawn_link(fun() -> merge_proc(Cask, Parent) end),
    start_mergers(State, Remaining - 1).
    
%% Ask the writer process to rewrite every key with the next sequence
%% number (seq + 1).
kick_writer(#state{writer_pid = Writer, seq = Seq, num_keys = NumKeys}) ->
    Writer ! {start, Seq + 1, NumKeys}.

%% Send stop to the writer process and wait for it to terminate.
%% Raises {writer_pid_timeout, Pid} if no 'DOWN' message for the monitor
%% arrives within 10 seconds.
stop_writer(#state{writer_pid = Writer}) ->
%%X    io:format("Stopping writer ~p\n", [Writer]),
    Monitor = erlang:monitor(process, Writer),
    Writer ! stop,
    receive
        {'DOWN', Monitor, _Type, _Object, _Info} ->
%%X           io:format("Stopped writer ~p\n", [Writer]),
            ok
    after 10000 ->
            erlang:error({writer_pid_timeout, Writer})
    end.
        
%% Writer process loop.  On {start, Seq, NumKeys} it rewrites every key
%% with Seq, tells Caller {write_done, Seq} and waits for the next
%% request.  On stop it tells Caller write_exit and returns, ending the
%% process.
write_proc(Ref, Caller) ->
    receive
        {start, Seq, NumKeys} ->
            write(Ref, NumKeys, Seq),
            Caller ! {write_done, Seq},
            write_proc(Ref, Caller);
        stop ->
%%X            io:format("Writer ~p received stop request\n", [self()]),
            Caller ! write_exit
    end.

%% Store {Key, Seq} under the 32-bit binary form of Key, for Key counting
%% down from the given value to 1.
write(_Ref, 0, _Seq) ->
    ok;
write(Ref, Key, Seq) ->
    BinKey = <<Key:32>>,
    Value = term_to_binary({Key, Seq}),
    bitcask:put(Ref, BinKey, Value),
    write(Ref, Key - 1, Seq).

%% Body of a reader process: open the cask, check every key with
%% read/3, close the cask and tell Caller {read_done, Seq}.
read_proc(Cask, NumKeys, Seq, Caller) ->
    Handle = bitcask:open(Cask),
    read(Handle, NumKeys, Seq),
    bitcask:close(Handle),
    Done = {read_done, Seq},
    Caller ! Done.

%% Fetch each key from the given value down to 1 and assert that the
%% stored term is {Key, Seq} with Seq no lower than MinSeq.  Any
%% mismatch crashes the calling process.
read(_Ref, 0, _Iter) ->
    ok;
read(Ref, Key, MinSeq) ->
    BinKey = <<Key:32>>,
    {ok, Stored} = bitcask:get(Ref, BinKey),
    {Key, Seq} = binary_to_term(Stored),
    true = (Seq >= MinSeq),
    read(Ref, Key - 1, MinSeq).
   
%% Body of a folder process: open the cask, check it with fold/3, close
%% the cask and tell Caller {fold_done, Seq}.
fold_proc(Cask, NumKeys, Seq, Caller) ->
    Handle = bitcask:open(Cask),
    fold(Handle, NumKeys, Seq),
    bitcask:close(Handle),
    Done = {fold_done, Seq},
    Caller ! Done.

%% Fold over every key/value pair in the cask, asserting that each value
%% is {Key, Seq} with Seq no lower than MinSeq, then check that the
%% collected keys are the run starting at 1 and ending at NumKeys.
fold(Ref, NumKeys, MinSeq) ->
    Collect = fun(<<Key:32>>, Value, {Floor, Seen}) ->
                      {Key, Seq} = binary_to_term(Value),
                      true = (Seq >= Floor),
                      {Floor, [Key | Seen]}
              end,
    Initial = {MinSeq, []},
    {MinSeq, FoldedKeys} = bitcask:fold(Ref, Collect, Initial),
    Sorted = lists:sort(FoldedKeys),
    check_fold(1, NumKeys, Sorted).

%% Body of a fold_keys process: open the cask, check it with
%% fold_keys/2, close the cask and tell Caller {fold_keys_done, Seq}.
fold_keys_proc(Cask, NumKeys, Seq, Caller) ->
    Handle = bitcask:open(Cask),
    fold_keys(Handle, NumKeys),
    bitcask:close(Handle),
    Done = {fold_keys_done, Seq},
    Caller ! Done.

%% Collect every key in the cask with bitcask:fold_keys and check that
%% the keys are the run starting at 1 and ending at NumKeys.
fold_keys(Ref, NumKeys) ->
    Collect = fun(#bitcask_entry{key = <<Key:32>>}, Seen) ->
                      [Key | Seen]
              end,
    FoldedKeys = bitcask:fold_keys(Ref, Collect, []),
    Sorted = lists:sort(FoldedKeys),
    check_fold(1, NumKeys, Sorted).

%% Walk a sorted list of folded keys, expecting Key, Key + 1, ... in
%% order.  Returns ok when the list ends exactly after MaxKey.  At the
%% first point where the expected key is not the next element the key is
%% reported as missing and the check stops.
check_fold(Key, MaxKey, Keys) ->
    case Keys of
        [Key | Rest] ->
            check_fold(Key + 1, MaxKey, Rest);
        [] when Key == MaxKey + 1 ->
            ok;
        [] ->
            io:format("Fold missing key ~p (stopping searching)\n", [Key]);
        [_Unexpected | _Rest] ->
            io:format("Fold missing key ~p (stopping searching)\n", [Key])
    end.
    


%% Body of a merger process: merge the cask, asserting that the merge
%% returns ok, then tell Caller merge_done.
merge_proc(Cask, Caller) ->
    MergeResult = bitcask:merge(Cask),
    ok = MergeResult,
    Caller ! merge_done.


%% Make sure the bitcask module is loaded.  When it cannot be loaded
%% from the current code path, search for it starting at the current
%% working directory with find_bitcask/1.
ensure_bitcask() ->
    case code:ensure_loaded(bitcask) of
        {module, bitcask} ->
            ok;
        _NotLoaded ->
            {ok, Cwd} = file:get_cwd(),
            CwdParts = filename:split(Cwd),
            find_bitcask(CwdParts)
    end.

%% Look for bitcask.beam in Cwd and Cwd/ebin, then in each parent
%% directory in turn.  Cwd is a list of path components as produced by
%% filename:split/1.  Raises an error once the search reaches "/".
find_bitcask(["/"]) ->
    erlang:error("Could not find bitcask\n");
find_bitcask(Cwd) ->
    case try_bitcask_dir(Cwd) orelse try_bitcask_dir(Cwd ++ ["ebin"]) of
        true ->
            ok;
        false ->
            find_bitcask(parent_dir(Cwd))
    end.
         
%% Check whether the directory given as a list of path components holds
%% the bitcask object file.  If it does, the directory is appended to
%% the code path, bitcask is loaded and true is returned; otherwise
%% false.
try_bitcask_dir(Dir) ->
    CodeDir = filename:join(Dir),
    Beam = bitcask_beam(CodeDir),
    io:format("Looking for bitcask in \"~s\".\n", [CodeDir]),
    case filelib:is_regular(Beam) of
        false ->
            false;
        true ->
            io:format("Adding bitcask dir \"~s\".\n",
                      [CodeDir]),
            code:add_pathz(CodeDir),
            {module, bitcask} = code:ensure_loaded(bitcask),
            true
    end.

%% Append the ebin directory of every dependency found under
%% <bitcask dir>/deps/*/ebin to the code path, where <bitcask dir> is
%% the parent of the directory containing bitcask.beam.
%%
%% code:where_is_file/1 returns the atom non_existing when the file is
%% not on the code path; that case raises bitcask_beam_not_found
%% instead of being treated as a file name.
ensure_deps() ->
    case code:where_is_file("bitcask.beam") of
        non_existing ->
            erlang:error(bitcask_beam_not_found);
        BitcaskBeam ->
            BeamDir = filename:split(filename:dirname(BitcaskBeam)),
            BitcaskDir = parent_dir(BeamDir),
            Pattern = filename:join(BitcaskDir) ++ "/deps/*/ebin",
            Deps = filelib:wildcard(Pattern),
            AddDepDir = fun(DepDir) ->
                                io:format("Adding dependency dir \"~s\".\n",
                                          [DepDir]),
                                code:add_pathz(DepDir)
                        end,
            lists:foreach(AddDepDir, Deps)
    end.
                  
%% Return the parent of a directory given as a list of path components,
%% i.e. the list without its last element.  The empty list and ["/"]
%% both map to ["/"].
parent_dir([]) ->
    ["/"];
parent_dir(["/"]) ->
    ["/"];
parent_dir(Dirs) ->
    lists:droplast(Dirs).

%% Path of the bitcask object file inside directory Cwd.
bitcask_beam(Cwd) ->
    BeamName = "bitcask" ++ code:objfile_extension(),
    filename:join(Cwd, BeamName).

%% Apply a list of {Name, Value} arguments to State one at a time.
%% Returns {ok, NewState} when every argument was applied; the first
%% result from process_arg/2 that is not {ok, _} ends processing and is
%% returned unchanged.
process_args([], State) ->
    {ok, State};
process_args([Arg | Remaining], State) ->
    Result = process_arg(Arg, State),
    case Result of
        {ok, Updated} ->
            process_args(Remaining, Updated);
        _NotOk ->
            Result
    end.

%% Apply one {Name, Value} argument to State.
%%
%% Returns syntax for help, otherwise {ok, NewState} with the record
%% field called Name replaced:
%%   - numeric fields take an integer or a string holding an integer;
%%   - restart_writer takes true or false, as an atom or as the string
%%     "true" / "false".  Anything else raises
%%     {badarg, {restart_writer, Val}} immediately, so a bad value is
%%     rejected here rather than being stored in the state, and no atom
%%     is ever created from the supplied text;
%%   - any other field is stored as given.
%% A Name that is not a field of #state raises {badarg, Name} (from
%% get_state_index/1).
process_arg({help, _}, _State) ->
    syntax;
process_arg({Name, Val}, State) when Name =:= duration;
                                     Name =:= status_freq;
                                     Name =:= num_keys;
                                     Name =:= readers;
                                     Name =:= folders;
                                     Name =:= foldkeys;
                                     Name =:= mergers;
                                     Name =:= max_file_size;
                                     Name =:= open_timeout ->
    Index = get_state_index(Name),
    IntVal = case is_integer(Val) of
                 true ->
                     Val;
                 false ->
                     list_to_integer(Val)
             end,
    {ok, setelement(Index, State, IntVal)};
process_arg({restart_writer, Val}, State) ->
    Index = get_state_index(restart_writer),
    Bool = case Val of
               true ->
                   true;
               false ->
                   false;
               "true" ->
                   true;
               "false" ->
                   false;
               _ ->
                   erlang:error({badarg, {restart_writer, Val}})
           end,
    {ok, setelement(Index, State, Bool)};
process_arg({Name, Val}, State) ->
    {ok, setelement(get_state_index(Name), State, Val)}.

%% Turn one command line argument of the form Name or Name=Value, with
%% any number of leading dashes, into {NameAtom, Value}.  Value is the
%% string after "=", or undefined when there is none.
%% Raises {badarg, Arg} when the argument does not split into one or two
%% "="-separated tokens and {badarg, NameStr} when the name is not an
%% existing atom.
parse_arg([$- | Arg]) -> % strip leading --'s
    parse_arg(Arg);
parse_arg(Arg) ->
    {NameStr, Val} = case string:tokens(Arg, "=") of
                         [OnlyName] ->
                             {OnlyName, undefined};
                         [GivenName, GivenVal] ->
                             {GivenName, GivenVal};
                         _ ->
                             erlang:error({badarg, Arg})
                     end,
    try list_to_existing_atom(NameStr) of
        Name ->
            {Name, Val}
    catch
        error:badarg ->
            erlang:error({badarg, NameStr})
    end.
    

%% Tuple position of the #state field called Name, for use with
%% setelement/3.  Raises {badarg, Name} when there is no such field.
get_state_index(Name) ->
    Fields = record_info(fields, state),
    %% Element 1 of the record tuple is the tag 'state', so the first
    %% field is element 2.
    get_state_index(Name, 2, Fields).

%% Search the list of field names for Name, where Index is the tuple
%% position of the first name in the list.  Returns the position of
%% Name; when the list runs out a message is printed and {badarg, Name}
%% is raised.
get_state_index(Name, Index, [Name | _Rest]) ->
    Index;
get_state_index(Name, Index, [_OtherName | Rest]) ->
    get_state_index(Name, Index + 1, Rest);
get_state_index(Name, _Index, []) ->
    io:format("Cannot find index in state for ~p\n", [Name]),
    erlang:error({badarg, Name}).


%% Print the test parameters held in the state record, one per line,
%% followed by a blank line.
print_params(#state{duration = Duration,
                    cask = Cask,
                    num_keys = NumKeys,
                    readers = Readers,
                    folders = Folders,
                    foldkeys = FoldKeys,
                    mergers = Mergers,
                    max_file_size = MaxFileSize,
                    open_timeout = OpenTimeout}) ->
    io:format("Bitcask Test\n"),
    io:format("duration:      ~s\n", [format_duration(Duration)]),
    io:format("cask:          ~s\n", [Cask]),
    io:format("num_keys:      ~b\n", [NumKeys]),
    io:format("readers:       ~b\n", [Readers]),
    io:format("folders:       ~b\n", [Folders]),
    io:format("foldkeys:      ~b\n", [FoldKeys]),
    io:format("mergers:       ~b\n", [Mergers]),
    io:format("max_file_size: ~s\n", [format_max_file_size(MaxFileSize)]),
    io:format("open_timeout:  ~s\n", [format_timeout(OpenTimeout)]),
    io:format("\n").

%% Text for the duration parameter: "once" when no duration is set,
%% otherwise the value followed by " ms" (as io_lib character data).
format_duration(Duration) ->
    case Duration of
        undefined ->
            "once";
        Msecs ->
            io_lib:format("~p ms", [Msecs])
    end.

%% Text for the open_timeout parameter: "default" when it is not set,
%% otherwise the integer followed by " s" (as io_lib character data).
format_timeout(Timeout) ->
    case Timeout of
        undefined ->
            "default";
        Secs ->
            io_lib:format("~b s", [Secs])
    end.
    
%% Text for the max_file_size parameter: the number when it is a
%% positive integer (as io_lib character data), otherwise "default".
format_max_file_size(MaxFileSize) ->
    if
        is_integer(MaxFileSize), MaxFileSize > 0 ->
            io_lib:format("~b", [MaxFileSize]);
        true ->
            "default"
    end.
