# test.py
# Unit tests for the functions imported from the `tree` module.
from data import get_spam_train_data, get_spam_valid_data, get_college_data
from tree import c45, print_tree, get_entropy, split_data, find_best_threshold, find_best_threshold_fast, find_best_split, submission, accuracy
from unittest import TestCase
import unittest
class TreeTest(unittest.TestCase):
    """Test case for the functions imported from ``tree``.

    Only ``testsubmission`` is active. The unit tests for the individual
    helpers (entropy, splitting, threshold search) are kept below as
    commented-out code.
    """

    def setUp(self):
        """Load the college data set into ``self.data`` before each test."""
        self.data = get_college_data()

    # --- Disabled helper tests (kept for reference) -----------------------
    # def test_entropy(self):
    #     self.assertAlmostEqual(get_entropy(self.data), 0.9709505944546686)

    # def test_split(self):
    #     left, right = split_data(self.data, 0, 25)
    #     for point in left:
    #         self.assertLess(point.values[0], 25)
    #     self.assertEqual(len(left), 3)
    #     for point in right:
    #         self.assertGreaterEqual(point.values[0], 25)
    #     self.assertEqual(len(right), 7)

    # def test_threshold(self):
    #     gain, thresh = find_best_threshold(self.data, 1)
    #     self.assertAlmostEqual(gain, 0.321928094887)
    #     self.assertEqual(thresh, 38000)

    # def test_threshold_fast(self):
    #     gain, thresh = find_best_threshold_fast(self.data, 1)
    #     self.assertAlmostEqual(gain, 0.321928094887)
    #     self.assertEqual(thresh, 38000)

    # def test_best_split(self):
    #     feature, thresh = find_best_split(self.data)
    #     self.assertEqual(feature, 1)
    #     self.assertEqual(thresh, 38000)
    #     left, right = split_data(self.data, feature, thresh)
    #     feature, thresh = find_best_split(left)
    #     self.assertEqual(feature, None)
    #     self.assertEqual(thresh, None)
    #     feature, thresh = find_best_split(right)
    #     self.assertEqual(feature, 0)
    #     self.assertEqual(thresh, 43)

    # Debug output (Python 2 print statements), also disabled:
    # print find_best_threshold_fast(self.data,0)
    # print find_best_threshold_fast(self.data,1)
    # print find_best_threshold(self.data,0)
    # print find_best_threshold(self.data,1)

    # @unittest.skip("Comment out this line when ready.")
    def testsubmission(self):
        # Runs ``submission`` on the spam training and validation data,
        # scores the predictions with ``accuracy`` and requires the score
        # to be greater than 0.75.
        train = get_spam_train_data()
        valid = get_spam_valid_data()
        preds = submission(train, valid)
        acc = accuracy(valid, preds)
        # A parenthesized single-argument print gives the same output under
        # the Python 2 print statement and the Python 3 print function; the
        # original bare `print` statements are a SyntaxError on Python 3.
        print("")
        print("Your current accuracy is: %s" % (acc,))
        self.assertGreater(acc, .75)
# Run the test suite only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    unittest.main()