-
Notifications
You must be signed in to change notification settings - Fork 9
/
midi_to_df_conversion.py
439 lines (390 loc) · 16.7 KB
/
midi_to_df_conversion.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
import os
from pathlib import Path
from typing import List, Dict
import click
import numpy as np
import pandas as pd
from mido import MidiFile
from sklearn import preprocessing
from tqdm import tqdm
from midi_utility import get_note_tracks, get_midi_file_hash
from chord_identifier import chord_attributes
# TODO: parallelize this, so we can take advantage of multiple cores.
def midi_files_to_df(
    midi_filepaths: List[Path], skip_suspicious: bool = True
) -> pd.DataFrame:
    """Convert a collection of MIDI files into one concatenated DataFrame.

    Files that hash identically to an already-processed file are skipped, as
    are files with fewer than 25 unique velocity values when
    ``skip_suspicious`` is set (such files are unlikely to contain expressive
    velocity information worth learning from).

    :param midi_filepaths: paths of the MIDI files to convert.
    :param skip_suspicious: skip files with fewer than 25 unique velocities.
    :return: a single DataFrame with the rows of all converted files.
    """
    dfs = []
    # Maps each file hash to the first path seen with that hash, so exact
    # duplicates can be detected and reported.
    hashes_to_filenames: Dict[str, Path] = {}
    pbar = tqdm(midi_filepaths)
    for midi_filepath in pbar:
        pbar.set_description(f"midi_to_df_conversion converting {midi_filepath} to df")
        midi_file = MidiFile(midi_filepath)
        midi_file_hash = get_midi_file_hash(midi_file)
        if midi_file_hash in hashes_to_filenames:
            tqdm.write(
                f"midi_to_df_conversion skipping {midi_filepath} since an identical file exists "
                f"({hashes_to_filenames[midi_file_hash]})"
            )
            continue
        hashes_to_filenames[midi_file_hash] = midi_filepath
        try:
            df = _midi_file_to_df(midi_file)
            if skip_suspicious and len(df.velocity.unique()) < 25:
                tqdm.write(
                    f"midi_to_df_conversion skipping {midi_filepath} since it had few unique velocity values"
                )
                continue
            df["name"] = os.path.split(midi_file.filename)[-1]
            df = _add_engineered_features(df)
            assert not np.any(df.index.duplicated()), (midi_filepath, df)
            # Reduce memory footprint by downcasting 64-bit columns. The
            # velocity column is always downcast to float (even if integral)
            # since it is consumed as a float target downstream.
            for column in df:
                if column == "velocity" or df[column].dtype == "float64":
                    df[column] = pd.to_numeric(df[column], downcast="float")
                elif df[column].dtype == "int64":
                    df[column] = pd.to_numeric(df[column], downcast="integer")
            dfs.append(df)
        # TODO: catch more specific exception
        except Exception as e:
            tqdm.write(
                f"midi_to_df_conversion got exception converting midi to df: {e}"
            )
            # Bare raise preserves the original traceback (``raise e`` would
            # re-anchor it here).
            raise
    processed_count = len(dfs)
    total_count = len(midi_filepaths)
    click.echo(
        f"midi_to_df_conversion converted {processed_count} files out of {total_count} to dfs"
    )
    return pd.concat(dfs)
def _midi_file_to_df(midi_file) -> pd.DataFrame:
    """Convert one parsed MIDI file into a DataFrame with one row per note.

    Each row pairs features captured at note-on time (velocity, pitch,
    intervals, chord attributes, ...) with features captured at note-off time
    (sustain duration, number of notes still held, ...).

    :param midi_file: parsed MIDI file (``mido.MidiFile``).
    :return: DataFrame of completed notes sorted by note-on time, with a
        constant ``song_duration`` column appended.
    """
    note_tracks = get_note_tracks(midi_file)
    note_events = [
        (track.index, note_event)
        for track in note_tracks
        for note_event in track.note_events
    ]
    note_events.sort(key=lambda note_event: note_event[1].time)
    # After sorting, the last event's time marks the end of the song.
    song_duration = note_events[-1][1].time
    result = []
    # Notes currently held down, as (pitch, note-on time, note-on feature row)
    # tuples; used to pair each note-off with its originating note-on.
    currently_playing_notes = []
    # Note off-type events that arrived for a pitch that wasn't being played.
    skipped_note_offs = 0
    for track_index, event in note_events:
        if event.type == "note_on" and event.velocity > 0:
            # get interval after the last released note by getting that note and checking the difference between the
            # pitch values (result column 4 is the pitch)
            if len(result) > 0:
                interval_from_last_released_pitch = event.note - result[-1][4]
            else:
                interval_from_last_released_pitch = 0
            # get interval after the last pressed note in a similar manner
            if len(currently_playing_notes) > 0:
                interval_from_last_pressed_pitch = (
                    event.note - currently_playing_notes[-1][0]
                )
            else:
                interval_from_last_pressed_pitch = interval_from_last_released_pitch
            # get the average pitch of all notes currently being played
            curr_pitches = [p for p, _, _ in currently_playing_notes] + [event.note]
            average_pitch = np.mean(curr_pitches)
            # add features denoting the quality of chord being played. that means there are six possible values for the
            # "character":
            #
            # - is it minor?
            # - is it major?
            # - is it diminished?
            # - is it augmented?
            # - is it suspended?
            # - or none of the above.
            chord_attrs = chord_attributes(curr_pitches)
            chord_character = (
                chord_attrs[0]
                if chord_attrs is not None and chord_attrs[0] is not None
                else "none"
            )
            # and seven possible values for the number of notes:
            #
            # - is it a dyad?
            # - is it a triad?
            # - is it a seventh?
            # - is it a ninth?
            # - is it an eleventh?
            # - is it a thirteenth?
            # - or none of the above.
            chord_size = (
                chord_attrs[1]
                if chord_attrs is not None and chord_attrs[1] is not None
                else "none"
            )
            note_on_data = [
                event.velocity,
                event.time,
                track_index,
                event.index,
                event.note,
                str(event.note % 12),
                event.note // 12,
                average_pitch,
                event.time / song_duration,
                # parabola: 1.0 at the song's midpoint, 0.0 at both ends
                -(((event.time / song_duration) * 2 - 1) ** 2) + 1,
                interval_from_last_pressed_pitch,
                interval_from_last_released_pitch,
                len(currently_playing_notes) + 1,
                int(len(currently_playing_notes) == 0),
                chord_character,
                chord_size,
            ]
            currently_playing_notes.append((event.note, event.time, note_on_data))
        elif event.type == "note_off" or (
            event.type == "note_on" and event.velocity == 0
        ):
            if not (any(note == event.note for note, _, _ in currently_playing_notes)):
                # note off-type event for a pitch that isn't being played
                skipped_note_offs += 1
                continue
            note_on = next(
                x for x in currently_playing_notes if x[0] == event.note
            )
            _, note_on_time, note_on_data = note_on
            currently_playing_notes.remove(note_on)
            sustain_duration = event.time - note_on_time
            # if we get a note with a 0 sustain duration, use the duration of the previous note (if there is one);
            # result column 16 is the sustain duration
            if sustain_duration == 0:
                if len(result) > 0:
                    sustain_duration = result[-1][16]
                else:
                    tqdm.write(
                        f"midi_to_df_conversion warning: got first note with 0 duration; defaulting to 25"
                    )
                    sustain_duration = 25.0
            # get the average pitch of all notes currently being played
            curr_pitches = [p for p, _, _ in currently_playing_notes] + [event.note]
            average_pitch = np.mean(curr_pitches)
            note_off_data = [
                sustain_duration,
                len(currently_playing_notes),
                average_pitch,
            ]
            # add new row to result and sort all rows by note time (2nd column)
            result.append(note_on_data + note_off_data)
            result.sort(key=lambda row: row[1])
    # BUG FIX: previously computed as len(note_events) - len(result), which
    # over-counts massively (every completed note consumes two events but
    # yields one row), so the warning fired for virtually every file with a
    # bogus count. Now we report only genuinely unmatched note-offs.
    if skipped_note_offs > 0:
        tqdm.write(
            f"midi_to_df_conversion warning: saw {skipped_note_offs} note off events for pitches that hadn't been played"
        )
    df = pd.DataFrame(result)
    df.columns = [
        "velocity",
        "time",
        "midi_track_index",
        "midi_event_index",
        "pitch",
        "pitch_class",
        "octave",
        "avg_pitch_pressed",
        "nearness_to_end",
        "nearness_to_midpoint",
        "interval_from_pressed",
        "interval_from_released",
        "num_played_notes_pressed",
        "follows_pause",
        "chord_character_pressed",
        "chord_size_pressed",
        "sustain",
        "num_played_notes_released",
        "avg_pitch_released",
    ]
    df["song_duration"] = song_duration
    return df
def _add_engineered_features(
    df: pd.DataFrame, with_extra_features: bool = False
) -> pd.DataFrame:
    """Takes a data frame representing one MIDI song and adds a bunch of
    additional features to it.

    :param df: output of ``_midi_file_to_df`` for a single song (must contain
        the columns produced there, e.g. ``time``, ``sustain``, ``pitch``).
    :param with_extra_features: also compute percent-change, lag, and
        whole-song aggregate columns (off by default and never enabled by
        ``midi_files_to_df`` as visible here).
    :return: a new DataFrame: the input columns plus all engineered columns.
    """
    # NOTE: it's faster to create each column individually then merge them all together at the end. ("chord_character",
    # "chord_size", "time_since_last_pressed" and "time_since_last_released" are however needed in the df, so we add
    # those to the df directly.)
    new_cols: Dict[str, pd.Series] = {}
    # calculate "true" chord character and size by bunching all samples within 5 time units together and picking the
    # chord character and size of the last of each group for all of them. this makes it so that, if a chord is played
    # with not all notes perfectly at the same time, even the first notes here will get the information of the full
    # chord (hopefully).
    df["chord_character"] = df.groupby(
        np.floor(df.time / 5) * 5
    ).chord_character_pressed.transform("last")
    df["chord_size"] = df.groupby(
        np.floor(df.time / 5) * 5
    ).chord_size_pressed.transform("last")
    # get time elapsed since last note event(s)
    df["time_since_last_pressed"] = (df.time - df.time.shift()).fillna(0)
    df["time_since_last_released"] = (
        df.time - (df.time.shift() + df.sustain.shift())
    ).fillna(0)
    # get time elapsed since various further events. since some of these happen rather rarely (resulting in some very
    # large values), we also normalize.
    for cat in [
        "pitch_class",
        "octave",
        "follows_pause",
        "chord_character",
        "chord_size",
    ]:
        col_name = f"time_since_{cat}"
        # time since the previous row with the same value of `cat`,
        # standardized to zero mean / unit variance.
        col = pd.Series(
            preprocessing.scale(
                (df.time - df.groupby(cat)["time"].shift()).fillna(0).values
            )
        )
        new_cols[col_name] = col
        # NOTE(review): `col` is standardized and can be < -1, in which case
        # np.log(col + 1) yields NaN and the assertion at the bottom of this
        # function would fire — confirm whether inputs ever hit that range.
        new_cols[f"log_{col_name}"] = pd.Series(np.log(col + 1))
    # add some abs cols
    for col in ["interval_from_pressed", "interval_from_released"]:
        base = new_cols[col] if col in new_cols else df[col]
        new_cols[f"abs_{col}"] = np.abs(base)
    # add some log cols
    for col in [
        "time_since_chord_character",
        "time_since_chord_size",
        "time_since_follows_pause",
        "time_since_octave",
        "time_since_pitch_class",
    ]:
        # NOTE(review): these names were already given a log_ column (natural
        # log) in the loop above; this overwrites them with a log10-of-abs
        # variant — presumably intentional, but worth confirming.
        base = new_cols[col] if col in new_cols else df[col]
        new_cols[f"log_{col}"] = pd.Series(np.log10(np.abs(base) + 1))
    for col in [
        "sustain",
        "time_since_last_pressed",
        "time_since_last_released",
        "abs_interval_from_pressed",
        "abs_interval_from_released",
    ]:
        # natural-log-compress heavy-tailed duration/interval columns
        base = new_cols[col] if col in new_cols else df[col]
        new_cols[f"log_{col}"] = pd.Series(np.log(np.abs(base) + 1))
    # calculate some simple moving averages
    sma_aggs = {
        "pitch": ["mean", "min", "max", "std"],
        "log_sustain": ["mean", "min", "max", "std"],
        "interval_from_pressed": ["mean", "min", "max", "std"],
        "log_time_since_last_pressed": ["mean", "min", "max", "std"],
        "log_time_since_follows_pause": ["mean", "min", "max", "std"],
    }
    sma_windows = [15, 30, 75]
    for col, funcs in sma_aggs.items():
        base = new_cols[col] if col in new_cols else df[col]
        for window in sma_windows:
            for func in funcs:
                # backward-looking rolling aggregate; bfill() fills the first
                # `window - 1` rows that have no full window yet
                sma = base.rolling(window).agg(func).bfill()
                new_cols[f"{col}_sma_{func}_{window}"] = sma
                # forward-looking variant: reverse, roll, reverse back
                fwd_sma = base[::-1].rolling(window).agg(func).bfill()[::-1]
                new_cols[f"{col}_fwd_sma_{func}_{window}"] = fwd_sma
                # NOTE(review): "follows_pause" is not a key of sma_aggs
                # (only "log_time_since_follows_pause" is), so this condition
                # is always true as written — possibly a leftover guard.
                if col != "follows_pause":
                    new_cols[f"{col}_sma_{func}_{window}_oscillator"] = base - sma
                    new_cols[f"{col}_fwd_sma_{func}_{window}_oscillator"] = (
                        base - fwd_sma
                    )
    # add ichimoku indicators
    # (technical-analysis style features: rolling midpoints of 9/26/52-row
    # windows and their relationships to the series itself)
    for col in [
        "pitch",
        "log_sustain",
        "interval_from_released",
        "interval_from_pressed",
    ]:
        base = new_cols[col] if col in new_cols else df[col]
        tenkan_sen = (base.rolling(9).max() + base.rolling(9).min()).bfill() / 2.0
        kijun_sen = (base.rolling(26).max() + base.rolling(26).min()).bfill() / 2.0
        senkou_span_a = (tenkan_sen + kijun_sen) / 2.0
        senkou_span_b = (base.rolling(52).max() + base.rolling(52).min()).bfill() / 2.0
        new_cols[f"{col}_tenkan_sen"] = tenkan_sen
        new_cols[f"{col}_kijun_sen"] = kijun_sen
        new_cols[f"{col}_senkou_span_a"] = senkou_span_a
        new_cols[f"{col}_senkou_span_b"] = senkou_span_b
        new_cols[f"{col}_chikou_span"] = base.shift(26).bfill()
        new_cols[f"{col}_cloud_is_green"] = senkou_span_a - senkou_span_b
        new_cols[f"{col}_relative_to_tenkan_sen"] = base - tenkan_sen
        new_cols[f"{col}_relative_to_kijun_sen"] = base - kijun_sen
        new_cols[f"{col}_tenkan_sen_relative_to_kijun_sen"] = tenkan_sen - kijun_sen
        new_cols[f"{col}_relative_to_chikou_span"] = base - base.shift(26).bfill()
        new_cols[f"{col}_relative_to_cloud"] = (
            base - (senkou_span_a + senkou_span_b) / 2.0
        )
    if with_extra_features:
        # add percent change columns
        for col in [
            "pitch",
            "log_sustain",
            "num_played_notes_pressed",
            "num_played_notes_released",
            "interval_from_pressed",
            "interval_from_released",
            "log_time_since_last_pressed",
            "log_time_since_last_released",
        ]:
            base = new_cols[col] if col in new_cols else df[col]
            if col == "pitch":
                # pitch is always > 0, so a plain pct_change is safe
                new_cols[f"{col}_pct_change"] = base.pct_change().fillna(0.0)
            else:
                # shift to strictly positive values to avoid division by zero
                new_cols[f"{col}_pct_change"] = pd.Series(
                    (np.abs(base) + 1.0).pct_change().fillna(0.0)
                )
    # exponentially-weighted moving aggregates (backward and forward)
    ewm_aggs = {
        "pitch": ["mean", "std"],
        "log_sustain": ["mean", "std"],
        "num_played_notes_pressed": ["mean", "std"],
        "interval_from_pressed": ["mean", "std"],
        "log_abs_interval_from_released": ["mean", "std"],
        "log_time_since_last_pressed": ["mean", "std"],
        "log_time_since_follows_pause": ["mean", "std"],
    }
    for col, funcs in ewm_aggs.items():
        base = new_cols[col] if col in new_cols else df[col]
        for func in funcs:
            for span in [10, 20, 50]:
                new_cols[f"{col}_ewm_{func}_{span}"] = (
                    base.ewm(span=span).agg(func).bfill()
                )
                new_cols[f"{col}_fwd_ewm_{func}_{span}"] = (
                    base[::-1].ewm(span=span).agg(func).bfill()[::-1]
                )
            # actually macd uses ewms with spans 12 and 26 and a signal ewm with span 9. but 2x those works better.
            macd = (
                base.ewm(span=24).agg(func).bfill()
                - base.ewm(span=52).agg(func).bfill()
            )
            new_cols[f"{col}_ewm_{func}_macd"] = macd
            new_cols[f"{col}_ewm_{func}_macd_signal"] = (
                base.ewm(span=18).agg(func).bfill() - macd
            )
    if with_extra_features:
        # calculate lag values (just taking the values of the previous/next rows)
        for col in ["octave", "follows_pause", "chord_character", "chord_size"]:
            for i in range(1, 6):
                new_cols[f"{col}_lag_{i}"] = (
                    df[col].shift(i).bfill().astype(df[col].dtype)
                )
                new_cols[f"{col}_fwd_lag_{i}"] = (
                    df[col][::-1].shift(i).bfill()[::-1].astype(df[col].dtype)
                )
    if with_extra_features:
        # get some aggregate data of the song as a whole
        aggregators = {
            "pitch": ["sum", "mean", "min", "max", "std"],
            "log_sustain": ["sum", "mean", "min", "max", "std"],
            "octave": ["nunique"],
        }
        aggregated = df.agg(aggregators)
        # broadcast each scalar aggregate to a constant per-row column
        # NOTE(review): "log_sustain" lives in new_cols, not df, at this
        # point — df.agg on it would raise; verify against callers since
        # with_extra_features is never passed as True here.
        for col, funcs in aggregators.items():
            for func in funcs:
                new_cols[f"{col}_{func}"] = pd.Series([aggregated[col][func]] * len(df))
    if with_extra_features:
        # total number of notes in song
        note_count = pd.Series([len(df)] * len(df))
        new_cols["note_count"] = note_count
        new_cols["note_count_adj_by_dur"] = note_count / df.song_duration[0]
    # sanity-check every numeric engineered column before merging
    for name, new_col in new_cols.items():
        if not pd.api.types.is_numeric_dtype(new_col):
            continue
        assert not np.any(np.isnan(new_col)), (name, new_col)
        assert np.all(np.isfinite(new_col)), (name, new_col)
    return pd.concat(
        [df] + [col.rename(name) for name, col in new_cols.items()], axis=1
    )