-
Notifications
You must be signed in to change notification settings - Fork 10
/
8. set.py
42 lines (38 loc) · 1.16 KB
/
8. set.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from movies import training_set, training_labels, validation_set, validation_labels
def distance(movie1, movie2):
    """Return the Euclidean distance between two feature vectors.

    Coordinates are paired by index over the length of ``movie1``, so
    ``movie2`` must be at least as long; any entries of ``movie2`` beyond
    ``len(movie1)`` are not looked at.
    """
    # Sum of squared per-coordinate differences, then the square root.
    total = sum((coord - movie2[idx]) ** 2 for idx, coord in enumerate(movie1))
    return total ** 0.5
def classify(unknown, dataset, labels, k):
    """Predict a 0/1 label for ``unknown`` from its ``k`` nearest neighbors.

    Args:
        unknown: Feature vector of the point to classify.
        dataset: Mapping from title to feature vector.
        labels: Mapping from title to label; only labels equal to 0 or 1
            are counted as votes.
        k: Number of closest points that get to vote.

    Returns:
        1 if strictly more of the ``k`` nearest points are labeled 1 than
        labeled 0; otherwise 0 (so a tie yields 0).
    """
    # Pair every title with its distance to the unknown point; sorting the
    # pairs ranks the nearest first (equal distances fall back to comparing
    # titles).
    ranked = sorted(
        (distance(dataset[name], unknown), name) for name in dataset
    )

    # Keep only the k closest points.
    nearest = ranked[:k]

    # Tally the votes of the selected neighbors.
    votes = [labels[name] for _, name in nearest]
    good_count = sum(1 for vote in votes if vote == 1)
    bad_count = sum(1 for vote in votes if vote == 0)

    return 1 if good_count > bad_count else 0
# Show the features and the known label of the validation example.
print(validation_set["Bee Movie"])
print(validation_labels["Bee Movie"])

# Classify it against the training data using the 5 nearest neighbors.
guess = classify(
    validation_set["Bee Movie"],
    training_set,
    training_labels,
    5,
)
print(guess)

# Report whether the prediction matches the known label.
print("Correct!" if guess == validation_labels["Bee Movie"] else "Wrong!")