1#include <ruby.h>
2#include <ruby/st.h>
3
4static void
5numhash_free(void *ptr)
6{
7    if (ptr) st_free_table(ptr);
8}
9
10static VALUE
11numhash_alloc(VALUE klass)
12{
13    return Data_Wrap_Struct(klass, 0, numhash_free, 0);
14}
15
16static VALUE
17numhash_init(VALUE self)
18{
19    st_table *tbl = (st_table *)DATA_PTR(self);
20    if (tbl) st_free_table(tbl);
21    DATA_PTR(self) = st_init_numtable();
22    return self;
23}
24
25static VALUE
26numhash_aref(VALUE self, VALUE key)
27{
28    st_data_t data;
29    if (!SPECIAL_CONST_P(key)) rb_raise(rb_eArgError, "not a special const");
30    if (st_lookup((st_table *)DATA_PTR(self), (st_data_t)key, &data))
31	return (VALUE)data;
32    return Qnil;
33}
34
35static VALUE
36numhash_aset(VALUE self, VALUE key, VALUE data)
37{
38    if (!SPECIAL_CONST_P(key)) rb_raise(rb_eArgError, "not a special const");
39    if (!SPECIAL_CONST_P(data)) rb_raise(rb_eArgError, "not a special const");
40    st_insert((st_table *)DATA_PTR(self), (st_data_t)key, (st_data_t)data);
41    return self;
42}
43
44static int
45numhash_i(st_data_t key, st_data_t value, st_data_t arg)
46{
47    VALUE ret;
48    ret = rb_yield_values(3, (VALUE)key, (VALUE)value, (VALUE)arg);
49    if (ret == Qtrue) return ST_CHECK;
50    return ST_CONTINUE;
51}
52
53static VALUE
54numhash_each(VALUE self)
55{
56    st_table *table = DATA_PTR(self);
57    st_data_t data = (st_data_t)self;
58    return st_foreach_check(table, numhash_i, data, data) ? Qtrue : Qfalse;
59}
60
61static int
62update_func(st_data_t *key, st_data_t *value, st_data_t arg, int existing)
63{
64    VALUE ret = rb_yield_values(existing ? 2 : 1, (VALUE)*key, (VALUE)*value);
65    switch (ret) {
66      case Qfalse:
67	return ST_STOP;
68      case Qnil:
69	return ST_DELETE;
70      default:
71	*value = ret;
72	return ST_CONTINUE;
73    }
74}
75
76static VALUE
77numhash_update(VALUE self, VALUE key)
78{
79    if (st_update((st_table *)DATA_PTR(self), (st_data_t)key, update_func, 0))
80	return Qtrue;
81    else
82	return Qfalse;
83}
84
/*
 * ST2NUM(x): converts a pointer-sized unsigned integer (used below for
 * st_table::num_entries) to a Ruby Integer.  It picks the unsigned
 * conversion whose C type has the same size as a pointer.  If neither
 * long nor long long matches, compilation stops here with a clear
 * message instead of leaving ST2NUM undefined.
 */
#if SIZEOF_LONG == SIZEOF_VOIDP
# define ST2NUM(x) ULONG2NUM(x)
#elif SIZEOF_LONG_LONG == SIZEOF_VOIDP
# define ST2NUM(x) ULL2NUM(x)
#else
# error "ST2NUM: neither long nor long long has the size of a pointer"
#endif
90
91static VALUE
92numhash_size(VALUE self)
93{
94    return ST2NUM(((st_table *)DATA_PTR(self))->num_entries);
95}
96
97static VALUE
98numhash_delete_safe(VALUE self, VALUE key)
99{
100    st_data_t val, k = (st_data_t)key;
101    if (st_delete_safe((st_table *)DATA_PTR(self), &k, &val, (st_data_t)self)) {
102	return val;
103    }
104    return Qnil;
105}
106
107void
108Init_numhash(void)
109{
110    VALUE st = rb_define_class_under(rb_define_module("Bug"), "StNumHash", rb_cData);
111    rb_define_alloc_func(st, numhash_alloc);
112    rb_define_method(st, "initialize", numhash_init, 0);
113    rb_define_method(st, "[]", numhash_aref, 1);
114    rb_define_method(st, "[]=", numhash_aset, 2);
115    rb_define_method(st, "each", numhash_each, 0);
116    rb_define_method(st, "update", numhash_update, 1);
117    rb_define_method(st, "size", numhash_size, 0);
118    rb_define_method(st, "delete_safe", numhash_delete_safe, 1);
119}
120
121