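%
%  Fixed-point arithmetic.
%
%  A fixed-point value is an Erlang integer scaled by 2^16, so 1.0 is
%  represented as 1 bsl 16.  The nominal width is 26 bits:
%
%  bits 25..16: integer part
%  bits 15..0 : fraction part
%
%  Erlang integers are unbounded, so the 26-bit width is a convention;
%  nothing in this module clamps values to it.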


-module(fixed).
-author('marc.vanwoerkom@fernuni-hagen.de').

-include("constants.hrl").

-export([mul/2, 
	 divide/2,
	 sin/1, 
	 cos/1]).

-export([max/2,
	 min/2,
	 sqrt/1,
	 asin/1,
	 atan/1,
	 atan2/2,
	 anglediff/2]).

-export([fixed_to_float/1,
	 float_to_fixed/1,
	 integer_to_fixed/1,
	 test/0]).


%  Multiplication of two fixed-point numbers.
%
%  The raw product carries 32 fraction bits; shifting right by 16 restores
%  the 16-bit scale.  bsr is an arithmetic shift, so negative results are
%  rounded toward negative infinity (mul(40, -209) =:= -1).
mul(X, Y) when is_integer(X), is_integer(Y) ->
    (X * Y) bsr 16.


%  Division of two fixed-point numbers.
%
%  The dividend is pre-scaled by 2^16 so that the quotient keeps 16
%  fraction bits.  div truncates toward zero; a zero divisor raises
%  badarith.
divide(X, Y) when is_integer(X), is_integer(Y) ->
    (X bsl 16) div Y.


%  Sine of a fixed-point angle X (radians), using fixed-point arithmetic.
%
%  Negative angles use sin(-x) = -sin(x), and angles above pi/2 are folded
%  with sin(x) = sin(pi - x), until X lies in [0, pi/2].  There the
%  7th-order Taylor polynomial x - x^3/3! + x^5/5! - x^7/7! is evaluated.
%  Large angles are reduced one fold at a time, so the recursion depth
%  grows linearly with |X|.
sin(X) when is_integer(X), X < 0 ->
    -sin(-X);
sin(X) when is_integer(X), X > ?pi_div_2_fix ->
    sin(?pi_fix - X);
sin(X) when is_integer(X) ->
    X2 = mul(X, X),
    X3 = mul(X, X2),
    X5 = mul(X3, X2),
    X7 = mul(X5, X2),
    X - (X3 div 6) + (X5 div 120) - (X7 div 5040).
    

%  Cosine of a fixed-point angle X, computed as sin(X + pi/2).
%
%  When the shifted angle exceeds pi it is moved down by one full turn
%  (?two_pi_fix) before calling sin/1.
cos(X) when is_integer(X) ->
    Shifted = X + ?pi_div_2_fix,
    if
        Shifted > ?pi_fix ->
            sin(Shifted - ?two_pi_fix);
        true ->
            sin(Shifted)
    end.


%  Functions adapted from an "extended fixed-point arithmetic" reference page.

%  Larger of two integers (fixed-point or plain); X is returned when the
%  two are equal.
%
%  NOTE: this shadows the auto-imported erlang:max/2, so unqualified calls
%  to max/2 inside this module are ambiguous; call fixed:max/2 or
%  erlang:max/2 explicitly.
max(X, Y) when is_integer(X), is_integer(Y), X < Y ->
    Y;
max(X, Y) when is_integer(X), is_integer(Y) ->
    X.


%  Smaller of two integers (fixed-point or plain); Y is returned when the
%  two are equal.
%
%  NOTE: this shadows the auto-imported erlang:min/2, so unqualified calls
%  to min/2 inside this module are ambiguous; call fixed:min/2 or
%  erlang:min/2 explicitly.
min(X, Y) when is_integer(X), is_integer(Y), X < Y ->
    X;
min(X, Y) when is_integer(X), is_integer(Y) ->
    Y.


%  Square root of a fixed-point number X, found by binary search.
%
%  The search interval is [0, max(X, 1.0)], which contains the root: for
%  X >= 1.0, sqrt(X) =< X, and for X < 1.0, sqrt(X) < 1.0.  For negative X
%  the search collapses to 0.
%
%  erlang:max/2 is called explicitly.  This module defines its own max/2,
%  so an unqualified max/2 call is ambiguous with the auto-imported BIF.
%  For integers both functions return the same value.
sqrt(X) when is_integer(X) ->
    sqrt(X, erlang:max(X, ?one_fix), 0).

%  Binary-search step for sqrt/1.
%
%  Narrows [L, U] toward the smallest fixed-point G with mul(G, G) >= X and
%  returns it once the bounds meet.  Assumes mul(U, U) >= X on entry, which
%  sqrt/1 guarantees.
sqrt(X, U, U) when is_integer(X), is_integer(U) ->
    U;
sqrt(X, U, L) when is_integer(X), is_integer(U), is_integer(L) ->
    %% Midpoint of [L, U]: dividing by the fixed-point constant ?two_fix
    %% (presumably 2.0, i.e. 2 bsl 16) halves U + L.
    G = divide(U + L, ?two_fix),
    case mul(G, G) >= X of
        true ->
            sqrt(X, G, L);
        false ->
            sqrt(X, U, G + 1)
    end.


%  Arcsine of a fixed-point value X0, computed with Newton's method.
%
%  Solves sin(X) = X0 starting from X = 0, applying three steps of
%  X' = X - (sin(X) - X0) / cos(X).  X0 is not range-checked.  Accuracy is
%  presumably worse near |X0| = 1.0, where cos(X) approaches 0.
asin(X0) ->
    asin_newton(X0, 0, 3).

%  Apply N Newton steps for sin(X) = X0, starting from X.
asin_newton(_X0, X, 0) ->
    X;
asin_newton(X0, X, N) ->
    asin_newton(X0, X - divide(sin(X) - X0, cos(X)), N - 1).


%  Arctangent of a fixed-point value K.
%
%  Odd symmetry handles K < 0.  The identity atan(K) = pi/2 - atan(1/K)
%  handles K > 1.0.  The remaining range [0, 1.0] is computed as
%  asin(K / sqrt(1 + K^2)).
atan(K) when K < 0 ->
    -atan(-K);
atan(K) when K > ?one_fix ->
    Reciprocal = divide(?one_fix, K),
    ?pi_div_2_fix - atan(Reciprocal);
atan(K) ->
    Hypotenuse = sqrt(?one_fix + mul(K, K)),
    asin(divide(K, Hypotenuse)).


%  Four-quadrant arctangent of Y / X for fixed-point Y and X.  The result
%  is a fixed-point angle between -pi and pi.
%
%  Points on an axis are answered directly.  Otherwise the point is first
%  moved into the first quadrant:
%    - "flip": reflect through the origin when X < 0;
%    - "mirror": reflect across the X axis when the resulting Y is negative.
%  The angle found there is then un-flipped (subtract pi) and un-mirrored
%  (negate).
atan2(Y, 0) when Y > 0 ->
    ?pi_div_2_fix;
atan2(Y, 0) when Y < 0 ->
    -?pi_div_2_fix;
atan2(_Y, 0) ->
    0;
atan2(0, X) when X > 0 ->
    0;
atan2(0, _X) ->
    ?pi_fix;
atan2(Y, X) ->
    {Y1, X1, Flip} = case X < 0 of
                         true -> {-Y, -X, true};
                         false -> {Y, X, false}
                     end,
    {Y2, Mirror} = case Y1 < 0 of
                       true -> {-Y1, true};
                       false -> {Y1, false}
                   end,
    Angle = atan2_first_quadrant(Y2, X1),
    Unflipped = case Flip of
                    true -> Angle - ?pi_fix;
                    false -> Angle
                end,
    case Mirror of
        true -> -Unflipped;
        false -> Unflipped
    end.

%  atan(Y / X) for X > 0 and Y > 0.  When Y >= X the complementary form
%  pi/2 - atan(X / Y) is used, so atan/1 never receives a ratio above 1.0.
atan2_first_quadrant(Y, X) when Y < X ->
    atan(divide(Y, X));
atan2_first_quadrant(Y, X) ->
    ?pi_div_2_fix - atan(divide(X, Y)).


%  Signed difference U - V between two fixed-point angles, wrapped toward
%  [-pi, pi] by adding or subtracting one full turn (?two_pi_fix).
%
%  Only a single correction is applied.  A difference that lies more than
%  one turn outside that range is therefore not fully normalised.
anglediff(U, V) ->
    Diff = U - V,
    if
        Diff < -?pi_fix ->
            Diff + ?two_pi_fix;
        Diff > ?pi_fix ->
            Diff - ?two_pi_fix;
        true ->
            Diff
    end.

    
%  conversion

%  Convert a fixed-point integer to a float by dividing out the 2^16 scale.
fixed_to_float(X) when is_integer(X) ->
    X / (1 bsl 16).

%  Convert a float to fixed point, rounding to the nearest 1/2^16.
float_to_fixed(X) when is_float(X) ->
    round(X * (1 bsl 16)).

%  Convert a plain integer to fixed point (exact; no fraction bits set).
integer_to_fixed(X) when is_integer(X) ->
    X * (1 bsl 16).


%  test stuff

%  Print a labelled fixed-point value next to its float equivalent.
print_fixed(Description, X) ->
    AsFloat = fixed_to_float(X),
    FormatArgs = [Description, X, AsFloat],
    io:format("~s: fixed is ~w, float is ~w~n", FormatArgs).

%  Console self-test: prints each computed value next to the expected one.
%
%  The expected sqrt values are the results of the binary search in sqrt/1,
%  i.e. the smallest G with mul(G, G) >= X:
%    - sqrt(1234): exact root ~ 8992.86, result 8993;
%    - sqrt(123456): exact root ~ 89948.94, result 89949.
test() ->
    io:format("1_fix is ~w~n", [1 bsl 16]),
    io:format("2_fix is ~w~n", [2 bsl 16]),

    print_fixed("pi_div_2_fix", ?pi_div_2_fix),
    print_fixed("pi_fix", ?pi_fix),
    print_fixed("two_pi_fix", ?two_pi_fix),

    io:format("float pi to int: ~w~n", [float_to_fixed(math:pi())]),

    %% Values far beyond the nominal 26 bits still work: Erlang integers
    %% are unbounded.
    I = 1 bsl 25,
    II = mul(I, I),
    I2 = divide(II, I),
    io:format("~w * ~w = ~w~n", [I, I, II]),
    io:format("~w / ~w = ~w~n", [II, I, I2]),

    io:format("mul(40,-209) should be -1 and is ~w~n", [mul(40,-209)]),
    io:format("sin(pi/2) should be 65526 and is ~w~n", [sin(?pi_div_2_fix)]),

    %% sin^2 + cos^2 should come out close to 1.0 (65536).
    Phi = 123456789,
    S = sin(Phi),
    C = cos(Phi),
    E = mul(S, S) + mul(C,C),
    io:format("E = ~w = ~w~n", [E, fixed_to_float(E)]),

    io:format("cos(-10450) should be 64701 and is ~w~n", [cos(-10450)]),
    io:format("sqrt(1234) should be 8993 and is ~w~n", [sqrt(1234)]),
    io:format("mul(8997,8997) is ~w, mul(8993, 8993) is ~w~n", [mul(8997,8997), mul(8993,8993)]),
    io:format("sqrt(123456) should be 89949 and is ~w~n", [sqrt(123456)]),
    io:format("mul(89950,89950) is ~w, mul(89949, 89949) is ~w~n", [mul(89950,89950), mul(89949,89949)]),
    io:format("asin(0) should be 0 and is ~w~n", [asin(0)]),
    io:format("asin(30000) should be 31160 and is ~w~n", [asin(30000)]),
    io:format("atan(0) should be 0 and is ~w~n", [atan(0)]),
    io:format("atan(-100) should be -100 and is ~w~n", [atan(-100)]),
    io:format("atan(100000) should be 64926 and is ~w~n", [atan(100000)]),
    io:format("atan2(-2500, 123456) should be -1326 and is ~w~n", [atan2(-2500, 123456)]),

    io:format("fixed:test done.~n").